1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 #define DEBUG_TYPE "orc"
15 
16 namespace llvm {
17 namespace orc {
18 namespace shared {
19 
20 template <>
21 class SPSSerializationTraits<SPSRemoteSymbolLookupSetElement,
22                              SymbolLookupSet::value_type> {
23 public:
24   static size_t size(const SymbolLookupSet::value_type &V) {
25     return SPSArgList<SPSString, bool>::size(
26         *V.first, V.second == SymbolLookupFlags::RequiredSymbol);
27   }
28 
29   static bool serialize(SPSOutputBuffer &OB,
30                         const SymbolLookupSet::value_type &V) {
31     return SPSArgList<SPSString, bool>::serialize(
32         OB, *V.first, V.second == SymbolLookupFlags::RequiredSymbol);
33   }
34 };
35 
36 template <>
37 class TrivialSPSSequenceSerialization<SPSRemoteSymbolLookupSetElement,
38                                       SymbolLookupSet> {
39 public:
40   static constexpr bool available = true;
41 };
42 
43 template <>
44 class SPSSerializationTraits<SPSRemoteSymbolLookup,
45                              ExecutorProcessControl::LookupRequest> {
46   using MemberSerialization =
47       SPSArgList<SPSExecutorAddress, SPSRemoteSymbolLookupSet>;
48 
49 public:
50   static size_t size(const ExecutorProcessControl::LookupRequest &LR) {
51     return MemberSerialization::size(ExecutorAddress(LR.Handle), LR.Symbols);
52   }
53 
54   static bool serialize(SPSOutputBuffer &OB,
55                         const ExecutorProcessControl::LookupRequest &LR) {
56     return MemberSerialization::serialize(OB, ExecutorAddress(LR.Handle),
57                                           LR.Symbols);
58   }
59 };
60 
61 } // end namespace shared
62 
// Destructor. The object must already have been disconnected; disconnect()
// in this file is what sets the Disconnected flag checked here.
SimpleRemoteEPC::~SimpleRemoteEPC() {
  assert(Disconnected && "Destroyed without disconnection");
}
66 
67 Expected<tpctypes::DylibHandle>
68 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
69   Expected<tpctypes::DylibHandle> H((tpctypes::DylibHandle()));
70   if (auto Err = callSPSWrapper<shared::SPSLoadDylibSignature>(
71           LoadDylibAddr.getValue(), H, JDI.JITDispatchContextAddress,
72           StringRef(DylibPath), (uint64_t)0))
73     return std::move(Err);
74   return H;
75 }
76 
77 Expected<std::vector<tpctypes::LookupResult>>
78 SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
79   Expected<std::vector<tpctypes::LookupResult>> R(
80       (std::vector<tpctypes::LookupResult>()));
81 
82   if (auto Err = callSPSWrapper<shared::SPSLookupSymbolsSignature>(
83           LookupSymbolsAddr.getValue(), R, JDI.JITDispatchContextAddress,
84           Request))
85     return std::move(Err);
86   return R;
87 }
88 
89 Expected<int32_t> SimpleRemoteEPC::runAsMain(JITTargetAddress MainFnAddr,
90                                              ArrayRef<std::string> Args) {
91   int64_t Result = 0;
92   if (auto Err = callSPSWrapper<shared::SPSRunAsMainSignature>(
93           RunAsMainAddr.getValue(), Result, ExecutorAddress(MainFnAddr), Args))
94     return std::move(Err);
95   return Result;
96 }
97 
98 void SimpleRemoteEPC::callWrapperAsync(SendResultFunction OnComplete,
99                                        JITTargetAddress WrapperFnAddr,
100                                        ArrayRef<char> ArgBuffer) {
101   uint64_t SeqNo;
102   {
103     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
104     SeqNo = getNextSeqNo();
105     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
106     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
107   }
108 
109   if (auto Err = T->sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
110                                 ExecutorAddress(WrapperFnAddr), ArgBuffer)) {
111     getExecutionSession().reportError(std::move(Err));
112   }
113 }
114 
115 Error SimpleRemoteEPC::disconnect() {
116   Disconnected = true;
117   T->disconnect();
118   return Error::success();
119 }
120 
121 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
122 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
123                                ExecutorAddress TagAddr,
124                                SimpleRemoteEPCArgBytesVector ArgBytes) {
125   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
126   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
127     return make_error<StringError>("Unexpected opcode",
128                                    inconvertibleErrorCode());
129 
130   switch (OpC) {
131   case SimpleRemoteEPCOpcode::Setup:
132     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
133       return std::move(Err);
134     break;
135   case SimpleRemoteEPCOpcode::Hangup:
136     // FIXME: Put EPC into 'detached' state.
137     return SimpleRemoteEPCTransportClient::EndSession;
138   case SimpleRemoteEPCOpcode::Result:
139     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
140       return std::move(Err);
141     break;
142   case SimpleRemoteEPCOpcode::CallWrapper:
143     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
144     break;
145   }
146   return ContinueSession;
147 }
148 
149 void SimpleRemoteEPC::handleDisconnect(Error Err) {
150   PendingCallWrapperResultsMap TmpPending;
151 
152   {
153     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
154     std::swap(TmpPending, PendingCallWrapperResults);
155   }
156 
157   for (auto &KV : TmpPending)
158     KV.second(
159         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
160 
161   if (Err) {
162     // FIXME: Move ReportError to EPC.
163     if (ES)
164       ES->reportError(std::move(Err));
165     else
166       logAllUnhandledErrors(std::move(Err), errs(), "SimpleRemoteEPC: ");
167   }
168 }
169 
170 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
171 SimpleRemoteEPC::createMemoryManager() {
172   EPCGenericJITLinkMemoryManager::FuncAddrs FAs;
173   if (auto Err = getBootstrapSymbols(
174           {{FAs.Reserve, "__llvm_orc_memory_reserve"},
175            {FAs.Finalize, "__llvm_orc_memory_finalize"},
176            {FAs.Deallocate, "__llvm_orc_memory_deallocate"}}))
177     return std::move(Err);
178 
179   return std::make_unique<EPCGenericJITLinkMemoryManager>(*this, FAs);
180 }
181 
182 Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
183 SimpleRemoteEPC::createMemoryAccess() {
184 
185   return nullptr;
186 }
187 
188 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddress TagAddr,
189                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
190   if (SeqNo != 0)
191     return make_error<StringError>("Setup packet SeqNo not zero",
192                                    inconvertibleErrorCode());
193 
194   if (TagAddr)
195     return make_error<StringError>("Setup packet TagAddr not zero",
196                                    inconvertibleErrorCode());
197 
198   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
199   auto I = PendingCallWrapperResults.find(0);
200   assert(PendingCallWrapperResults.size() == 1 &&
201          I != PendingCallWrapperResults.end() &&
202          "Setup message handler not connectly set up");
203   auto SetupMsgHandler = std::move(I->second);
204   PendingCallWrapperResults.erase(I);
205 
206   auto WFR =
207       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
208   SetupMsgHandler(std::move(WFR));
209   return Error::success();
210 }
211 
212 void SimpleRemoteEPC::prepareToReceiveSetupMessage(
213     std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP) {
214   PendingCallWrapperResults[0] =
215       [&](shared::WrapperFunctionResult SetupMsgBytes) {
216         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
217           ExecInfoP.set_value(
218               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
219           return;
220         }
221         using SPSSerialize =
222             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
223         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
224         SimpleRemoteEPCExecutorInfo EI;
225         if (SPSSerialize::deserialize(IB, EI))
226           ExecInfoP.set_value(EI);
227         else
228           ExecInfoP.set_value(make_error<StringError>(
229               "Could not deserialize setup message", inconvertibleErrorCode()));
230       };
231 }
232 
233 Error SimpleRemoteEPC::setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
234                              SimpleRemoteEPCExecutorInfo EI) {
235   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
236   LLVM_DEBUG({
237     dbgs() << "SimpleRemoteEPC received setup message:\n"
238            << "  Triple: " << EI.TargetTriple << "\n"
239            << "  Page size: " << EI.PageSize << "\n"
240            << "  Bootstrap symbols:\n";
241     for (const auto &KV : EI.BootstrapSymbols)
242       dbgs() << "    " << KV.first() << ": "
243              << formatv("{0:x16}", KV.second.getValue()) << "\n";
244   });
245   this->T = std::move(T);
246   TargetTriple = Triple(EI.TargetTriple);
247   PageSize = EI.PageSize;
248   BootstrapSymbols = std::move(EI.BootstrapSymbols);
249 
250   if (auto Err = getBootstrapSymbols(
251           {{JDI.JITDispatchContextAddress, ExecutorSessionObjectName},
252            {JDI.JITDispatchFunctionAddress, DispatchFnName},
253            {LoadDylibAddr, "__llvm_orc_load_dylib"},
254            {LookupSymbolsAddr, "__llvm_orc_lookup_symbols"},
255            {RunAsMainAddr, "__llvm_orc_run_as_main"}}))
256     return Err;
257 
258   if (auto MemMgr = createMemoryManager()) {
259     OwnedMemMgr = std::move(*MemMgr);
260     this->MemMgr = OwnedMemMgr.get();
261   } else
262     return MemMgr.takeError();
263 
264   if (auto MemAccess = createMemoryAccess()) {
265     OwnedMemAccess = std::move(*MemAccess);
266     this->MemAccess = OwnedMemAccess.get();
267   } else
268     return MemAccess.takeError();
269 
270   return Error::success();
271 }
272 
273 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddress TagAddr,
274                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
275   SendResultFunction SendResult;
276 
277   if (TagAddr)
278     return make_error<StringError>("Unexpected TagAddr in result message",
279                                    inconvertibleErrorCode());
280 
281   {
282     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
283     auto I = PendingCallWrapperResults.find(SeqNo);
284     if (I == PendingCallWrapperResults.end())
285       return make_error<StringError>("No call for sequence number " +
286                                          Twine(SeqNo),
287                                      inconvertibleErrorCode());
288     SendResult = std::move(I->second);
289     PendingCallWrapperResults.erase(I);
290     releaseSeqNo(SeqNo);
291   }
292 
293   auto WFR =
294       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
295   SendResult(std::move(WFR));
296   return Error::success();
297 }
298 
299 void SimpleRemoteEPC::handleCallWrapper(
300     uint64_t RemoteSeqNo, ExecutorAddress TagAddr,
301     SimpleRemoteEPCArgBytesVector ArgBytes) {
302   assert(ES && "No ExecutionSession attached");
303   ES->runJITDispatchHandler(
304       [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
305         if (auto Err =
306                 T->sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
307                                ExecutorAddress(), {WFR.data(), WFR.size()}))
308           getExecutionSession().reportError(std::move(Err));
309       },
310       TagAddr.getValue(), ArgBytes);
311 }
312 
313 } // end namespace orc
314 } // end namespace llvm
315