1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15
16 #include "c_api.h"
17 #include "common.h"
18 #include "error.h"
19 #include "simple_packed_serialization.h"
20 #include <type_traits>
21
22 namespace __orc_rt {
23
24 /// C++ wrapper function result: Same as CWrapperFunctionResult but
25 /// auto-releases memory.
26 class WrapperFunctionResult {
27 public:
28 /// Create a default WrapperFunctionResult.
WrapperFunctionResult()29 WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); }
30
31 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
32 /// instance takes ownership of the result object and will automatically
33 /// call dispose on the result upon destruction.
WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R)34 WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {}
35
36 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
38
WrapperFunctionResult(WrapperFunctionResult && Other)39 WrapperFunctionResult(WrapperFunctionResult &&Other) {
40 __orc_rt_CWrapperFunctionResultInit(&R);
41 std::swap(R, Other.R);
42 }
43
44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
45 __orc_rt_CWrapperFunctionResult Tmp;
46 __orc_rt_CWrapperFunctionResultInit(&Tmp);
47 std::swap(Tmp, Other.R);
48 std::swap(R, Tmp);
49 return *this;
50 }
51
~WrapperFunctionResult()52 ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); }
53
54 /// Relinquish ownership of and return the
55 /// __orc_rt_CWrapperFunctionResult.
release()56 __orc_rt_CWrapperFunctionResult release() {
57 __orc_rt_CWrapperFunctionResult Tmp;
58 __orc_rt_CWrapperFunctionResultInit(&Tmp);
59 std::swap(R, Tmp);
60 return Tmp;
61 }
62
63 /// Get a pointer to the data contained in this instance.
data()64 const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); }
65
66 /// Returns the size of the data contained in this instance.
size()67 size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); }
68
69 /// Returns true if this value is equivalent to a default-constructed
70 /// WrapperFunctionResult.
empty()71 bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); }
72
73 /// Create a WrapperFunctionResult with the given size and return a pointer
74 /// to the underlying memory.
allocate(WrapperFunctionResult & R,size_t Size)75 static char *allocate(WrapperFunctionResult &R, size_t Size) {
76 __orc_rt_DisposeCWrapperFunctionResult(&R.R);
77 __orc_rt_CWrapperFunctionResultInit(&R.R);
78 return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size);
79 }
80
81 /// Copy from the given char range.
copyFrom(const char * Source,size_t Size)82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
83 return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
84 }
85
86 /// Copy from the given null-terminated string (includes the null-terminator).
copyFrom(const char * Source)87 static WrapperFunctionResult copyFrom(const char *Source) {
88 return __orc_rt_CreateCWrapperFunctionResultFromString(Source);
89 }
90
91 /// Copy from the given std::string (includes the null terminator).
copyFrom(const std::string & Source)92 static WrapperFunctionResult copyFrom(const std::string &Source) {
93 return copyFrom(Source.c_str());
94 }
95
96 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const char * Msg)97 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
98 return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
99 }
100
101 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const std::string & Msg)102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
103 return createOutOfBandError(Msg.c_str());
104 }
105
106 /// If this value is an out-of-band error then this returns the error message,
107 /// otherwise returns nullptr.
getOutOfBandError()108 const char *getOutOfBandError() const {
109 return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
110 }
111
112 private:
113 __orc_rt_CWrapperFunctionResult R;
114 };
115
116 namespace detail {
117
118 template <typename SPSArgListT, typename... ArgTs>
119 Expected<WrapperFunctionResult>
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args)120 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
121 WrapperFunctionResult Result;
122 char *DataPtr =
123 WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
124 SPSOutputBuffer OB(DataPtr, Result.size());
125 if (!SPSArgListT::serialize(OB, Args...))
126 return make_error<StringError>(
127 "Error serializing arguments to blob in call");
128 return std::move(Result);
129 }
130
/// Invokes a handler with the tuple elements selected by an index sequence and
/// returns exactly what the handler returns (decltype(auto) keeps references).
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename Fn, typename TupleT, std::size_t... Idx>
  static decltype(auto) call(Fn &&Handler, TupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    return std::forward<Fn>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
139
140 template <> class WrapperFunctionHandlerCaller<void> {
141 public:
142 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
call(HandlerT && H,ArgTupleT & Args,std::index_sequence<I...>)143 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
144 std::index_sequence<I...>) {
145 std::forward<HandlerT>(H)(std::get<I>(Args)...);
146 return SPSEmpty();
147 }
148 };
149
/// Fallback for callable objects: recover the handler signature from the type
/// of &operator() and defer to the specialization matching that
/// member-function-pointer type.
///
/// ResultSerializerT is a one-parameter template used to serialize the
/// handler's return value. SPSArgTagTs are the SPS tags of the arguments.
template <typename HandlerT, template <typename> class ResultSerializerT,
          typename... SPSArgTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<HandlerT>::operator()),
          ResultSerializerT, SPSArgTagTs...> {};
156
157 template <typename RetT, typename... ArgTs,
158 template <typename> class ResultSerializer, typename... SPSTagTs>
159 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
160 SPSTagTs...> {
161 public:
162 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
163 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
164
165 template <typename HandlerT>
apply(HandlerT && H,const char * ArgData,size_t ArgSize)166 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
167 size_t ArgSize) {
168 ArgTuple Args;
169 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
170 return WrapperFunctionResult::createOutOfBandError(
171 "Could not deserialize arguments for wrapper function call");
172
173 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
174 std::forward<HandlerT>(H), Args, ArgIndices{});
175
176 if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
177 std::move(HandlerResult)))
178 return std::move(*Result);
179 else
180 return WrapperFunctionResult::createOutOfBandError(
181 toString(Result.takeError()));
182 }
183
184 private:
185 template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)186 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
187 std::index_sequence<I...>) {
188 SPSInputBuffer IB(ArgData, ArgSize);
189 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
190 }
191
192 };
193
194 // Map function references to function types.
195 template <typename RetT, typename... ArgTs,
196 template <typename> class ResultSerializer, typename... SPSTagTs>
197 class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
198 SPSTagTs...>
199 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
200 SPSTagTs...> {};
201
202 // Map non-const member function types to function types.
203 template <typename ClassT, typename RetT, typename... ArgTs,
204 template <typename> class ResultSerializer, typename... SPSTagTs>
205 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
206 SPSTagTs...>
207 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
208 SPSTagTs...> {};
209
210 // Map const member function types to function types.
211 template <typename ClassT, typename RetT, typename... ArgTs,
212 template <typename> class ResultSerializer, typename... SPSTagTs>
213 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
214 ResultSerializer, SPSTagTs...>
215 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
216 SPSTagTs...> {};
217
218 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
219 public:
serialize(RetT Result)220 static Expected<WrapperFunctionResult> serialize(RetT Result) {
221 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
222 Result);
223 }
224 };
225
226 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
227 public:
serialize(Error Err)228 static Expected<WrapperFunctionResult> serialize(Error Err) {
229 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
230 toSPSSerializable(std::move(Err)));
231 }
232 };
233
234 template <typename SPSRetTagT, typename T>
235 class ResultSerializer<SPSRetTagT, Expected<T>> {
236 public:
serialize(Expected<T> E)237 static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
238 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
239 toSPSSerializable(std::move(E)));
240 }
241 };
242
243 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
244 public:
makeSafe(RetT & Result)245 static void makeSafe(RetT &Result) {}
246
deserialize(RetT & Result,const char * ArgData,size_t ArgSize)247 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
248 SPSInputBuffer IB(ArgData, ArgSize);
249 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
250 return make_error<StringError>(
251 "Error deserializing return value from blob in call");
252 return Error::success();
253 }
254 };
255
256 template <> class ResultDeserializer<SPSError, Error> {
257 public:
makeSafe(Error & Err)258 static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
259
deserialize(Error & Err,const char * ArgData,size_t ArgSize)260 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
261 SPSInputBuffer IB(ArgData, ArgSize);
262 SPSSerializableError BSE;
263 if (!SPSArgList<SPSError>::deserialize(IB, BSE))
264 return make_error<StringError>(
265 "Error deserializing return value from blob in call");
266 Err = fromSPSSerializable(std::move(BSE));
267 return Error::success();
268 }
269 };
270
271 template <typename SPSTagT, typename T>
272 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
273 public:
makeSafe(Expected<T> & E)274 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
275
deserialize(Expected<T> & E,const char * ArgData,size_t ArgSize)276 static Error deserialize(Expected<T> &E, const char *ArgData,
277 size_t ArgSize) {
278 SPSInputBuffer IB(ArgData, ArgSize);
279 SPSSerializableExpected<T> BSE;
280 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
281 return make_error<StringError>(
282 "Error deserializing return value from blob in call");
283 E = fromSPSSerializable(std::move(BSE));
284 return Error::success();
285 }
286 };
287
288 } // end namespace detail
289
/// Primary template for wrapper functions. Only specializations over a
/// function type (SPS return tag and SPS argument tags) are defined.
template <typename SPSSignatureT> class WrapperFunction;
291
292 template <typename SPSRetTagT, typename... SPSTagTs>
293 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
294 private:
295 template <typename RetT>
296 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
297
298 public:
299 template <typename RetT, typename... ArgTs>
call(const void * FnTag,RetT & Result,const ArgTs &...Args)300 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
301
302 // RetT might be an Error or Expected value. Set the checked flag now:
303 // we don't want the user to have to check the unused result if this
304 // operation fails.
305 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
306
307 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
308 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
309 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
310 return make_error<StringError>("__orc_rt_jit_dispatch not set");
311
312 auto ArgBuffer =
313 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
314 Args...);
315 if (!ArgBuffer)
316 return ArgBuffer.takeError();
317
318 WrapperFunctionResult ResultBuffer =
319 __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag,
320 ArgBuffer->data(), ArgBuffer->size());
321 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
322 return make_error<StringError>(ErrMsg);
323
324 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
325 Result, ResultBuffer.data(), ResultBuffer.size());
326 }
327
328 template <typename HandlerT>
handle(const char * ArgData,size_t ArgSize,HandlerT && Handler)329 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
330 HandlerT &&Handler) {
331 using WFHH =
332 detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
333 SPSTagTs...>;
334 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
335 }
336
337 private:
makeSerializable(const T & Value)338 template <typename T> static const T &makeSerializable(const T &Value) {
339 return Value;
340 }
341
makeSerializable(Error Err)342 static detail::SPSSerializableError makeSerializable(Error Err) {
343 return detail::toSPSSerializable(std::move(Err));
344 }
345
346 template <typename T>
makeSerializable(Expected<T> E)347 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
348 return detail::toSPSSerializable(std::move(E));
349 }
350 };
351
352 template <typename... SPSTagTs>
353 class WrapperFunction<void(SPSTagTs...)>
354 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
355 public:
356 template <typename... ArgTs>
call(const void * FnTag,const ArgTs &...Args)357 static Error call(const void *FnTag, const ArgTs &...Args) {
358 SPSEmpty BE;
359 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
360 }
361
362 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
363 };
364
365 } // end namespace __orc_rt
366
367 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
368