1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15 
16 #include "c_api.h"
17 #include "common.h"
18 #include "error.h"
19 #include "executor_address.h"
20 #include "simple_packed_serialization.h"
21 #include <type_traits>
22 
23 namespace __orc_rt {
24 
25 /// C++ wrapper function result: Same as CWrapperFunctionResult but
26 /// auto-releases memory.
27 class WrapperFunctionResult {
28 public:
29   /// Create a default WrapperFunctionResult.
30   WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); }
31 
32   /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33   /// instance takes ownership of the result object and will automatically
34   /// call dispose on the result upon destruction.
35   WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {}
36 
37   WrapperFunctionResult(const WrapperFunctionResult &) = delete;
38   WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
39 
40   WrapperFunctionResult(WrapperFunctionResult &&Other) {
41     __orc_rt_CWrapperFunctionResultInit(&R);
42     std::swap(R, Other.R);
43   }
44 
45   WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
46     __orc_rt_CWrapperFunctionResult Tmp;
47     __orc_rt_CWrapperFunctionResultInit(&Tmp);
48     std::swap(Tmp, Other.R);
49     std::swap(R, Tmp);
50     return *this;
51   }
52 
53   ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); }
54 
55   /// Relinquish ownership of and return the
56   /// __orc_rt_CWrapperFunctionResult.
57   __orc_rt_CWrapperFunctionResult release() {
58     __orc_rt_CWrapperFunctionResult Tmp;
59     __orc_rt_CWrapperFunctionResultInit(&Tmp);
60     std::swap(R, Tmp);
61     return Tmp;
62   }
63 
64   /// Get a pointer to the data contained in this instance.
65   char *data() { return __orc_rt_CWrapperFunctionResultData(&R); }
66 
67   /// Returns the size of the data contained in this instance.
68   size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); }
69 
70   /// Returns true if this value is equivalent to a default-constructed
71   /// WrapperFunctionResult.
72   bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); }
73 
74   /// Create a WrapperFunctionResult with the given size and return a pointer
75   /// to the underlying memory.
76   static WrapperFunctionResult allocate(size_t Size) {
77     WrapperFunctionResult R;
78     R.R = __orc_rt_CWrapperFunctionResultAllocate(Size);
79     return R;
80   }
81 
82   /// Copy from the given char range.
83   static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
84     return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
85   }
86 
87   /// Copy from the given null-terminated string (includes the null-terminator).
88   static WrapperFunctionResult copyFrom(const char *Source) {
89     return __orc_rt_CreateCWrapperFunctionResultFromString(Source);
90   }
91 
92   /// Copy from the given std::string (includes the null terminator).
93   static WrapperFunctionResult copyFrom(const std::string &Source) {
94     return copyFrom(Source.c_str());
95   }
96 
97   /// Create an out-of-band error by copying the given string.
98   static WrapperFunctionResult createOutOfBandError(const char *Msg) {
99     return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
100   }
101 
102   /// Create an out-of-band error by copying the given string.
103   static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
104     return createOutOfBandError(Msg.c_str());
105   }
106 
107   /// If this value is an out-of-band error then this returns the error message,
108   /// otherwise returns nullptr.
109   const char *getOutOfBandError() const {
110     return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
111   }
112 
113 private:
114   __orc_rt_CWrapperFunctionResult R;
115 };
116 
117 namespace detail {
118 
119 template <typename SPSArgListT, typename... ArgTs>
120 WrapperFunctionResult
121 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
122   auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...));
123   SPSOutputBuffer OB(Result.data(), Result.size());
124   if (!SPSArgListT::serialize(OB, Args...))
125     return WrapperFunctionResult::createOutOfBandError(
126         "Error serializing arguments to blob in call");
127   return Result;
128 }
129 
/// Invokes a handler with the elements of an argument tuple, selected by the
/// index pack, and returns exactly what the handler returns (decltype(auto)
/// preserves reference return types).
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    // std::get on an lvalue tuple yields lvalue references, so the handler
    // sees (and may modify) the tuple's own elements.
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
138 
139 template <> class WrapperFunctionHandlerCaller<void> {
140 public:
141   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
142   static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
143                        std::index_sequence<I...>) {
144     std::forward<HandlerT>(H)(std::get<I>(Args)...);
145     return SPSEmpty();
146   }
147 };
148 
/// Maps a handler type to the function type used to deserialize its
/// arguments and serialize its result.
///
/// This primary template handles callable class types (lambdas and function
/// objects such as MethodWrapperHandler) by recursing on the type of their
/// call operator. The partial specializations below reduce function pointers
/// and member function pointers to a plain function type RetT(ArgTs...), where
/// the actual work is done. Because &T::operator() must name a single member
/// function, callables with an overloaded or templated operator() (e.g. generic
/// lambdas) are not supported.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
          ResultSerializer, SPSTagTs...> {};
155 
156 template <typename RetT, typename... ArgTs,
157           template <typename> class ResultSerializer, typename... SPSTagTs>
158 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
159                                    SPSTagTs...> {
160 public:
161   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
162   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
163 
164   template <typename HandlerT>
165   static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
166                                      size_t ArgSize) {
167     ArgTuple Args;
168     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
169       return WrapperFunctionResult::createOutOfBandError(
170           "Could not deserialize arguments for wrapper function call");
171 
172     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
173         std::forward<HandlerT>(H), Args, ArgIndices{});
174 
175     return ResultSerializer<decltype(HandlerResult)>::serialize(
176         std::move(HandlerResult));
177   }
178 
179 private:
180   template <std::size_t... I>
181   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
182                           std::index_sequence<I...>) {
183     SPSInputBuffer IB(ArgData, ArgSize);
184     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
185   }
186 };
187 
188 // Map function pointers to function types.
189 template <typename RetT, typename... ArgTs,
190           template <typename> class ResultSerializer, typename... SPSTagTs>
191 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
192                                    SPSTagTs...>
193     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
194                                           SPSTagTs...> {};
195 
196 // Map non-const member function types to function types.
197 template <typename ClassT, typename RetT, typename... ArgTs,
198           template <typename> class ResultSerializer, typename... SPSTagTs>
199 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
200                                    SPSTagTs...>
201     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
202                                           SPSTagTs...> {};
203 
204 // Map const member function types to function types.
205 template <typename ClassT, typename RetT, typename... ArgTs,
206           template <typename> class ResultSerializer, typename... SPSTagTs>
207 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
208                                    ResultSerializer, SPSTagTs...>
209     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
210                                           SPSTagTs...> {};
211 
212 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
213 public:
214   static WrapperFunctionResult serialize(RetT Result) {
215     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
216         Result);
217   }
218 };
219 
220 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
221 public:
222   static WrapperFunctionResult serialize(Error Err) {
223     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
224         toSPSSerializable(std::move(Err)));
225   }
226 };
227 
228 template <typename SPSRetTagT, typename T>
229 class ResultSerializer<SPSRetTagT, Expected<T>> {
230 public:
231   static WrapperFunctionResult serialize(Expected<T> E) {
232     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
233         toSPSSerializable(std::move(E)));
234   }
235 };
236 
237 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
238 public:
239   static void makeSafe(RetT &Result) {}
240 
241   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
242     SPSInputBuffer IB(ArgData, ArgSize);
243     if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
244       return make_error<StringError>(
245           "Error deserializing return value from blob in call");
246     return Error::success();
247   }
248 };
249 
250 template <> class ResultDeserializer<SPSError, Error> {
251 public:
252   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
253 
254   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
255     SPSInputBuffer IB(ArgData, ArgSize);
256     SPSSerializableError BSE;
257     if (!SPSArgList<SPSError>::deserialize(IB, BSE))
258       return make_error<StringError>(
259           "Error deserializing return value from blob in call");
260     Err = fromSPSSerializable(std::move(BSE));
261     return Error::success();
262   }
263 };
264 
265 template <typename SPSTagT, typename T>
266 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
267 public:
268   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
269 
270   static Error deserialize(Expected<T> &E, const char *ArgData,
271                            size_t ArgSize) {
272     SPSInputBuffer IB(ArgData, ArgSize);
273     SPSSerializableExpected<T> BSE;
274     if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
275       return make_error<StringError>(
276           "Error deserializing return value from blob in call");
277     E = fromSPSSerializable(std::move(BSE));
278     return Error::success();
279   }
280 };
281 
282 } // end namespace detail
283 
284 template <typename SPSSignature> class WrapperFunction;
285 
286 template <typename SPSRetTagT, typename... SPSTagTs>
287 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
288 private:
289   template <typename RetT>
290   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
291 
292 public:
293   template <typename RetT, typename... ArgTs>
294   static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
295 
296     // RetT might be an Error or Expected value. Set the checked flag now:
297     // we don't want the user to have to check the unused result if this
298     // operation fails.
299     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
300 
301     if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
302       return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
303     if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
304       return make_error<StringError>("__orc_rt_jit_dispatch not set");
305 
306     auto ArgBuffer =
307         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
308             Args...);
309     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
310       return make_error<StringError>(ErrMsg);
311 
312     WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
313         &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
314     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
315       return make_error<StringError>(ErrMsg);
316 
317     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
318         Result, ResultBuffer.data(), ResultBuffer.size());
319   }
320 
321   template <typename HandlerT>
322   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
323                                       HandlerT &&Handler) {
324     using WFHH =
325         detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
326                                              ResultSerializer, SPSTagTs...>;
327     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
328   }
329 
330 private:
331   template <typename T> static const T &makeSerializable(const T &Value) {
332     return Value;
333   }
334 
335   static detail::SPSSerializableError makeSerializable(Error Err) {
336     return detail::toSPSSerializable(std::move(Err));
337   }
338 
339   template <typename T>
340   static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
341     return detail::toSPSSerializable(std::move(E));
342   }
343 };
344 
345 template <typename... SPSTagTs>
346 class WrapperFunction<void(SPSTagTs...)>
347     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
348 public:
349   template <typename... ArgTs>
350   static Error call(const void *FnTag, const ArgTs &...Args) {
351     SPSEmpty BE;
352     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
353   }
354 
355   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
356 };
357 
358 /// A function object that takes an ExecutorAddr as its first argument,
359 /// casts that address to a ClassT*, then calls the given method on that
360 /// pointer passing in the remaining function arguments. This utility
361 /// removes some of the boilerplate from writing wrappers for method calls.
362 ///
363 ///   @code{.cpp}
364 ///   class MyClass {
365 ///   public:
366 ///     void myMethod(uint32_t, bool) { ... }
367 ///   };
368 ///
369 ///   // SPS Method signature -- note MyClass object address as first argument.
370 ///   using SPSMyMethodWrapperSignature =
371 ///     SPSTuple<SPSExecutorAddr, uint32_t, bool>;
372 ///
373 ///   WrapperFunctionResult
374 ///   myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
375 ///     return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
376 ///        ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
377 ///   }
378 ///   @endcode
379 ///
380 template <typename RetT, typename ClassT, typename... ArgTs>
381 class MethodWrapperHandler {
382 public:
383   using MethodT = RetT (ClassT::*)(ArgTs...);
384   MethodWrapperHandler(MethodT M) : M(M) {}
385   RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
386     return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
387   }
388 
389 private:
390   MethodT M;
391 };
392 
393 /// Create a MethodWrapperHandler object from the given method pointer.
394 template <typename RetT, typename ClassT, typename... ArgTs>
395 MethodWrapperHandler<RetT, ClassT, ArgTs...>
396 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
397   return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
398 }
399 
400 } // end namespace __orc_rt
401 
402 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
403