1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "c_api.h" 17 #include "common.h" 18 #include "error.h" 19 #include "simple_packed_serialization.h" 20 #include <type_traits> 21 22 namespace __orc_rt { 23 24 /// C++ wrapper function result: Same as CWrapperFunctionResult but 25 /// auto-releases memory. 26 class WrapperFunctionResult { 27 public: 28 /// Create a default WrapperFunctionResult. 29 WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); } 30 31 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 32 /// instance takes ownership of the result object and will automatically 33 /// call dispose on the result upon destruction. 34 WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {} 35 36 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 38 39 WrapperFunctionResult(WrapperFunctionResult &&Other) { 40 __orc_rt_CWrapperFunctionResultInit(&R); 41 std::swap(R, Other.R); 42 } 43 44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 45 __orc_rt_CWrapperFunctionResult Tmp; 46 __orc_rt_CWrapperFunctionResultInit(&Tmp); 47 std::swap(Tmp, Other.R); 48 std::swap(R, Tmp); 49 return *this; 50 } 51 52 ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); } 53 54 /// Relinquish ownership of and return the 55 /// __orc_rt_CWrapperFunctionResult. 56 __orc_rt_CWrapperFunctionResult release() { 57 __orc_rt_CWrapperFunctionResult Tmp; 58 __orc_rt_CWrapperFunctionResultInit(&Tmp); 59 std::swap(R, Tmp); 60 return Tmp; 61 } 62 63 /// Get a pointer to the data contained in this instance. 64 char *data() { return __orc_rt_CWrapperFunctionResultData(&R); } 65 66 /// Returns the size of the data contained in this instance. 67 size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } 68 69 /// Returns true if this value is equivalent to a default-constructed 70 /// WrapperFunctionResult. 71 bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); } 72 73 /// Create a WrapperFunctionResult with the given size and return a pointer 74 /// to the underlying memory. 75 static WrapperFunctionResult allocate(size_t Size) { 76 WrapperFunctionResult R; 77 R.R = __orc_rt_CWrapperFunctionResultAllocate(Size); 78 return R; 79 } 80 81 /// Copy from the given char range. 82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 83 return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 84 } 85 86 /// Copy from the given null-terminated string (includes the null-terminator). 87 static WrapperFunctionResult copyFrom(const char *Source) { 88 return __orc_rt_CreateCWrapperFunctionResultFromString(Source); 89 } 90 91 /// Copy from the given std::string (includes the null terminator). 92 static WrapperFunctionResult copyFrom(const std::string &Source) { 93 return copyFrom(Source.c_str()); 94 } 95 96 /// Create an out-of-band error by copying the given string. 97 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 98 return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 99 } 100 101 /// Create an out-of-band error by copying the given string. 102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 103 return createOutOfBandError(Msg.c_str()); 104 } 105 106 /// If this value is an out-of-band error then this returns the error message, 107 /// otherwise returns nullptr. 108 const char *getOutOfBandError() const { 109 return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 110 } 111 112 private: 113 __orc_rt_CWrapperFunctionResult R; 114 }; 115 116 namespace detail { 117 118 template <typename SPSArgListT, typename... ArgTs> 119 Expected<WrapperFunctionResult> 120 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { 121 auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...)); 122 SPSOutputBuffer OB(Result.data(), Result.size()); 123 if (!SPSArgListT::serialize(OB, Args...)) 124 return make_error<StringError>( 125 "Error serializing arguments to blob in call"); 126 return std::move(Result); 127 } 128 129 template <typename RetT> class WrapperFunctionHandlerCaller { 130 public: 131 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 132 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 133 std::index_sequence<I...>) { 134 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 135 } 136 }; 137 138 template <> class WrapperFunctionHandlerCaller<void> { 139 public: 140 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 141 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 142 std::index_sequence<I...>) { 143 std::forward<HandlerT>(H)(std::get<I>(Args)...); 144 return SPSEmpty(); 145 } 146 }; 147 148 template <typename WrapperFunctionImplT, 149 template <typename> class ResultSerializer, typename... SPSTagTs> 150 class WrapperFunctionHandlerHelper 151 : public WrapperFunctionHandlerHelper< 152 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 153 ResultSerializer, SPSTagTs...> {}; 154 155 template <typename RetT, typename... ArgTs, 156 template <typename> class ResultSerializer, typename... SPSTagTs> 157 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 158 SPSTagTs...> { 159 public: 160 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 161 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 162 163 template <typename HandlerT> 164 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 165 size_t ArgSize) { 166 ArgTuple Args; 167 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 168 return WrapperFunctionResult::createOutOfBandError( 169 "Could not deserialize arguments for wrapper function call"); 170 171 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 172 std::forward<HandlerT>(H), Args, ArgIndices{}); 173 174 if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize( 175 std::move(HandlerResult))) 176 return std::move(*Result); 177 else 178 return WrapperFunctionResult::createOutOfBandError( 179 toString(Result.takeError())); 180 } 181 182 private: 183 template <std::size_t... I> 184 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 185 std::index_sequence<I...>) { 186 SPSInputBuffer IB(ArgData, ArgSize); 187 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 188 } 189 190 }; 191 192 // Map function references to function types. 193 template <typename RetT, typename... ArgTs, 194 template <typename> class ResultSerializer, typename... SPSTagTs> 195 class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer, 196 SPSTagTs...> 197 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 198 SPSTagTs...> {}; 199 200 // Map non-const member function types to function types. 201 template <typename ClassT, typename RetT, typename... ArgTs, 202 template <typename> class ResultSerializer, typename... SPSTagTs> 203 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 204 SPSTagTs...> 205 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 206 SPSTagTs...> {}; 207 208 // Map const member function types to function types. 209 template <typename ClassT, typename RetT, typename... ArgTs, 210 template <typename> class ResultSerializer, typename... SPSTagTs> 211 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 212 ResultSerializer, SPSTagTs...> 213 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 214 SPSTagTs...> {}; 215 216 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 217 public: 218 static Expected<WrapperFunctionResult> serialize(RetT Result) { 219 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 220 Result); 221 } 222 }; 223 224 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 225 public: 226 static Expected<WrapperFunctionResult> serialize(Error Err) { 227 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 228 toSPSSerializable(std::move(Err))); 229 } 230 }; 231 232 template <typename SPSRetTagT, typename T> 233 class ResultSerializer<SPSRetTagT, Expected<T>> { 234 public: 235 static Expected<WrapperFunctionResult> serialize(Expected<T> E) { 236 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 237 toSPSSerializable(std::move(E))); 238 } 239 }; 240 241 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 242 public: 243 static void makeSafe(RetT &Result) {} 244 245 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 246 SPSInputBuffer IB(ArgData, ArgSize); 247 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 248 return make_error<StringError>( 249 "Error deserializing return value from blob in call"); 250 return Error::success(); 251 } 252 }; 253 254 template <> class ResultDeserializer<SPSError, Error> { 255 public: 256 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 257 258 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 259 SPSInputBuffer IB(ArgData, ArgSize); 260 SPSSerializableError BSE; 261 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 262 return make_error<StringError>( 263 "Error deserializing return value from blob in call"); 264 Err = fromSPSSerializable(std::move(BSE)); 265 return Error::success(); 266 } 267 }; 268 269 template <typename SPSTagT, typename T> 270 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 271 public: 272 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 273 274 static Error deserialize(Expected<T> &E, const char *ArgData, 275 size_t ArgSize) { 276 SPSInputBuffer IB(ArgData, ArgSize); 277 SPSSerializableExpected<T> BSE; 278 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 279 return make_error<StringError>( 280 "Error deserializing return value from blob in call"); 281 E = fromSPSSerializable(std::move(BSE)); 282 return Error::success(); 283 } 284 }; 285 286 } // end namespace detail 287 288 template <typename SPSSignature> class WrapperFunction; 289 290 template <typename SPSRetTagT, typename... SPSTagTs> 291 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 292 private: 293 template <typename RetT> 294 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 295 296 public: 297 template <typename RetT, typename... ArgTs> 298 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { 299 300 // RetT might be an Error or Expected value. Set the checked flag now: 301 // we don't want the user to have to check the unused result if this 302 // operation fails. 303 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 304 305 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 306 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set"); 307 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) 308 return make_error<StringError>("__orc_rt_jit_dispatch not set"); 309 310 auto ArgBuffer = 311 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( 312 Args...); 313 if (!ArgBuffer) 314 return ArgBuffer.takeError(); 315 316 WrapperFunctionResult ResultBuffer = 317 __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag, 318 ArgBuffer->data(), ArgBuffer->size()); 319 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 320 return make_error<StringError>(ErrMsg); 321 322 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 323 Result, ResultBuffer.data(), ResultBuffer.size()); 324 } 325 326 template <typename HandlerT> 327 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 328 HandlerT &&Handler) { 329 using WFHH = 330 detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer, 331 SPSTagTs...>; 332 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 333 } 334 335 private: 336 template <typename T> static const T &makeSerializable(const T &Value) { 337 return Value; 338 } 339 340 static detail::SPSSerializableError makeSerializable(Error Err) { 341 return detail::toSPSSerializable(std::move(Err)); 342 } 343 344 template <typename T> 345 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 346 return detail::toSPSSerializable(std::move(E)); 347 } 348 }; 349 350 template <typename... SPSTagTs> 351 class WrapperFunction<void(SPSTagTs...)> 352 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 353 public: 354 template <typename... ArgTs> 355 static Error call(const void *FnTag, const ArgTs &...Args) { 356 SPSEmpty BE; 357 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...); 358 } 359 360 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 361 }; 362 363 } // end namespace __orc_rt 364 365 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 366