12487db1fSLang Hames //===- ExecutionSessionWrapperFunctionCallsTest.cpp -- Test wrapper calls -===//
22487db1fSLang Hames //
32487db1fSLang Hames // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42487db1fSLang Hames // See https://llvm.org/LICENSE.txt for license information.
52487db1fSLang Hames // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62487db1fSLang Hames //
72487db1fSLang Hames //===----------------------------------------------------------------------===//
82487db1fSLang Hames
92487db1fSLang Hames #include "llvm/ExecutionEngine/Orc/Core.h"
102487db1fSLang Hames #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
112487db1fSLang Hames #include "llvm/Support/MSVCErrorWorkarounds.h"
122487db1fSLang Hames #include "llvm/Testing/Support/Error.h"
132487db1fSLang Hames #include "gtest/gtest.h"
142487db1fSLang Hames
152487db1fSLang Hames #include <future>
162487db1fSLang Hames
172487db1fSLang Hames using namespace llvm;
182487db1fSLang Hames using namespace llvm::orc;
192487db1fSLang Hames using namespace llvm::orc::shared;
202487db1fSLang Hames
addWrapper(const char * ArgData,size_t ArgSize)21*213666f8SLang Hames static llvm::orc::shared::CWrapperFunctionResult addWrapper(const char *ArgData,
22*213666f8SLang Hames size_t ArgSize) {
232487db1fSLang Hames return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
242487db1fSLang Hames ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; })
252487db1fSLang Hames .release();
262487db1fSLang Hames }
272487db1fSLang Hames
addAsyncWrapper(unique_function<void (int32_t)> SendResult,int32_t X,int32_t Y)282487db1fSLang Hames static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
292487db1fSLang Hames int32_t X, int32_t Y) {
302487db1fSLang Hames SendResult(X + Y);
312487db1fSLang Hames }
322487db1fSLang Hames
33*213666f8SLang Hames static llvm::orc::shared::CWrapperFunctionResult
voidWrapper(const char * ArgData,size_t ArgSize)348a367502SLang Hames voidWrapper(const char *ArgData, size_t ArgSize) {
358a367502SLang Hames return WrapperFunction<void()>::handle(ArgData, ArgSize, []() {}).release();
368a367502SLang Hames }
378a367502SLang Hames
/// Calls addWrapper synchronously through ExecutionSession::callSPSWrapper
/// with arguments (2, 3) and expects the call to succeed with result 5.
TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperTemplate) {
  ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));

  // Initialized so that, if the call fails and never writes Result, the
  // EXPECT_EQ below compares a defined value instead of reading an
  // indeterminate one (undefined behavior). The failure itself is already
  // reported by EXPECT_THAT_ERROR.
  int32_t Result = 0;
  EXPECT_THAT_ERROR(ES.callSPSWrapper<int32_t(int32_t, int32_t)>(
                        ExecutorAddr::fromPtr(addWrapper), Result, 2, 3),
                    Succeeded());
  EXPECT_EQ(Result, 5);
  cantFail(ES.endSession());
}
482487db1fSLang Hames
/// Calls voidWrapper through ExecutionSession::callSPSWrapperAsync and checks
/// that the completion handler reports no serialization error.
TEST(ExecutionSessionWrapperFunctionCalls, RunVoidWrapperAsyncTemplate) {
  ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));

  // MSVCPError is used instead of Error as the promise payload (the file
  // includes MSVCErrorWorkarounds.h for this purpose).
  std::promise<MSVCPError> ErrPromise;
  auto ErrFuture = ErrPromise.get_future();

  ES.callSPSWrapperAsync<void()>(
      ExecutorAddr::fromPtr(voidWrapper),
      [&ErrPromise](Error SerializationErr) {
        ErrPromise.set_value(std::move(SerializationErr));
      });

  // Block until the completion handler has run, then check its error.
  Error Err = ErrFuture.get();
  EXPECT_THAT_ERROR(std::move(Err), Succeeded());
  cantFail(ES.endSession());
}
618a367502SLang Hames
/// Calls addWrapper asynchronously through ExecutionSession::callSPSWrapperAsync
/// with arguments (2, 3) and expects the completion handler to deliver 5.
TEST(ExecutionSessionWrapperFunctionCalls, RunNonVoidWrapperAsyncTemplate) {
  ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));

  std::promise<MSVCPExpected<int32_t>> RP;
  ES.callSPSWrapperAsync<int32_t(int32_t, int32_t)>(
      ExecutorAddr::fromPtr(addWrapper),
      [&](Error SerializationErr, int32_t R) {
        // A std::promise may be satisfied only once. On a serialization error,
        // publish the error and stop. Falling through to a second set_value
        // would throw std::future_error (or abort in -fno-exceptions builds)
        // and mask the real failure.
        if (SerializationErr) {
          RP.set_value(std::move(SerializationErr));
          return;
        }
        RP.set_value(R);
      },
      2, 3);
  Expected<int32_t> Result = RP.get_future().get();
  EXPECT_THAT_EXPECTED(Result, HasValue(5));
  cantFail(ES.endSession());
}
782487db1fSLang Hames
/// Registers addAsyncWrapper as a JIT dispatch handler under the tag symbol
/// "addAsync_tag" (at address 0x01), then invokes it via runJITDispatchHandler
/// with SPS-serialized arguments (1, 2) and expects the deserialized result 3.
TEST(ExecutionSessionWrapperFunctionCalls, RegisterAsyncHandlerAndRun) {

  // Address assigned to the tag symbol. It is only used as a lookup key for
  // the dispatch handler, never dereferenced.
  constexpr JITTargetAddress AddAsyncTagAddr = 0x01;

  ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
  auto &JD = ES.createBareJITDylib("JD");

  // Define the tag symbol so that registerJITDispatchHandlers can resolve it.
  auto AddAsyncTag = ES.intern("addAsync_tag");
  cantFail(JD.define(absoluteSymbols(
      {{AddAsyncTag,
        JITEvaluatedSymbol(AddAsyncTagAddr, JITSymbolFlags::Exported)}})));

  ExecutionSession::JITDispatchHandlerAssociationMap Associations;

  Associations[AddAsyncTag] =
      ES.wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper);

  cantFail(ES.registerJITDispatchHandlers(JD, std::move(Associations)));

  std::promise<int32_t> RP;
  auto RF = RP.get_future();

  // Serialize the argument list (1, 2) into a buffer for the handler.
  using ArgSerialization = SPSArgList<int32_t, int32_t>;
  size_t ArgBufferSize = ArgSerialization::size(1, 2);
  auto ArgBuffer = WrapperFunctionResult::allocate(ArgBufferSize);
  SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
  EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2));

  ES.runJITDispatchHandler(
      [&](WrapperFunctionResult ResultBuffer) {
        // Initialized so that a failed deserialization (reported by the
        // EXPECT_TRUE below) does not publish an indeterminate value.
        int32_t Result = 0;
        SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size());
        EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result));
        RP.set_value(Result);
      },
      AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size()));

  EXPECT_EQ(RF.get(), (int32_t)3);

  cantFail(ES.endSession());
}
120