1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
12 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 #define DEBUG_TYPE "orc"
16 
17 namespace llvm {
18 namespace orc {
19 
20 SimpleRemoteEPC::~SimpleRemoteEPC() {
21 #ifndef NDEBUG
22   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
23   assert(Disconnected && "Destroyed without disconnection");
24 #endif // NDEBUG
25 }
26 
27 Expected<tpctypes::DylibHandle>
28 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
29   return DylibMgr->open(DylibPath, 0);
30 }
31 
32 Expected<std::vector<tpctypes::LookupResult>>
33 SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
34   std::vector<tpctypes::LookupResult> Result;
35 
36   for (auto &Element : Request) {
37     if (auto R = DylibMgr->lookup(Element.Handle, Element.Symbols)) {
38       Result.push_back({});
39       Result.back().reserve(R->size());
40       for (auto Addr : *R)
41         Result.back().push_back(Addr.getValue());
42     } else
43       return R.takeError();
44   }
45   return std::move(Result);
46 }
47 
48 Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
49                                              ArrayRef<std::string> Args) {
50   int64_t Result = 0;
51   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
52           RunAsMainAddr, Result, ExecutorAddr(MainFnAddr), Args))
53     return std::move(Err);
54   return Result;
55 }
56 
57 void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
58                                        IncomingWFRHandler OnComplete,
59                                        ArrayRef<char> ArgBuffer) {
60   uint64_t SeqNo;
61   {
62     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
63     SeqNo = getNextSeqNo();
64     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
65     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
66   }
67 
68   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
69                              WrapperFnAddr, ArgBuffer)) {
70     getExecutionSession().reportError(std::move(Err));
71   }
72 }
73 
74 Error SimpleRemoteEPC::disconnect() {
75   T->disconnect();
76   std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);
77   DisconnectCV.wait(Lock, [this] { return Disconnected; });
78   return std::move(DisconnectErr);
79 }
80 
/// Handle one incoming message (SimpleRemoteEPCTransportClient callback,
/// presumably invoked by the transport for every message it receives).
///
/// Dispatches on OpC:
///   - Setup       -> handleSetup
///   - Hangup      -> disconnect the transport, then handleHangup
///   - Result      -> handleResult
///   - CallWrapper -> handleCallWrapper
///
/// Returns EndSession after a successfully handled Hangup and ContinueSession
/// after any other successfully handled message. Returns an error if OpC is
/// past LastOpC or if handleSetup, handleHangup or handleResult fail.
Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
                               ExecutorAddr TagAddr,
                               SimpleRemoteEPCArgBytesVector ArgBytes) {

  // Debug trace of the incoming message. Note: the asserts below sit inside
  // LLVM_DEBUG, so they are only evaluated when debug output is enabled for
  // this DEBUG_TYPE, not in every assertions-enabled build.
  LLVM_DEBUG({
    dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";
    switch (OpC) {
    case SimpleRemoteEPCOpcode::Setup:
      dbgs() << "Setup";
      assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
      break;
    case SimpleRemoteEPCOpcode::Hangup:
      dbgs() << "Hangup";
      assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
      break;
    case SimpleRemoteEPCOpcode::Result:
      dbgs() << "Result";
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
      break;
    case SimpleRemoteEPCOpcode::CallWrapper:
      dbgs() << "CallWrapper";
      break;
    }
    dbgs() << ", seqno = " << SeqNo
           << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
           << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
           << " bytes\n";
  });

  // Reject opcode values beyond the last known opcode before dispatching.
  using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
  if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
    return make_error<StringError>("Unexpected opcode",
                                   inconvertibleErrorCode());

  switch (OpC) {
  case SimpleRemoteEPCOpcode::Setup:
    if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
      return std::move(Err);
    break;
  case SimpleRemoteEPCOpcode::Hangup:
    // Disconnect first; the hangup payload is decoded afterwards, and any
    // error it carries is returned instead of EndSession.
    T->disconnect();
    if (auto Err = handleHangup(std::move(ArgBytes)))
      return std::move(Err);
    return EndSession;
  case SimpleRemoteEPCOpcode::Result:
    if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
      return std::move(Err);
    break;
  case SimpleRemoteEPCOpcode::CallWrapper:
    // handleCallWrapper reports its own failures; it returns nothing here.
    handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
    break;
  }
  return ContinueSession;
}
138 
/// Handle disconnection of the connection to the executor (presumably called
/// by the transport once the connection is closed).
///
/// Fails every still-pending call-wrapper handler with an out-of-band
/// "disconnecting" error. It then joins Err into DisconnectErr, sets
/// Disconnected, and wakes all threads waiting on DisconnectCV (i.e. callers
/// blocked in disconnect()).
void SimpleRemoteEPC::handleDisconnect(Error Err) {
  LLVM_DEBUG({
    dbgs() << "SimpleRemoteEPC::handleDisconnect: "
           << (Err ? "failure" : "success") << "\n";
  });

  PendingCallWrapperResultsMap TmpPending;

  // Take ownership of all pending handlers under the lock...
  {
    std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
    std::swap(TmpPending, PendingCallWrapperResults);
  }

  // ...but invoke them after releasing it, so no handler runs with the
  // mutex held.
  for (auto &KV : TmpPending)
    KV.second(
        shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));

  // Publish the disconnected state and wake any disconnect() waiters.
  std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));
  Disconnected = true;
  DisconnectCV.notify_all();
}
161 
162 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
163 SimpleRemoteEPC::createMemoryManager() {
164   EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
165   if (auto Err = getBootstrapSymbols(
166           {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
167            {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
168            {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
169            {SAs.Deallocate,
170             rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
171     return std::move(Err);
172 
173   return std::make_unique<EPCGenericJITLinkMemoryManager>(*this, SAs);
174 }
175 
/// Create the memory-access implementation for this EPC.
///
/// Currently a stub: it always succeeds with a null pointer, so setup() ends
/// up storing a null OwnedMemAccess / MemAccess.
Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
SimpleRemoteEPC::createMemoryAccess() {
  // NOTE(review): anything that dereferences MemAccess after setup() will hit
  // a null pointer. EPCGenericMemoryAccess.h is included by this file but not
  // used yet; presumably this is meant to construct an EPCGenericMemoryAccess
  // from bootstrap symbols once the executor exports them. Confirm before
  // relying on memory access through this EPC.
  return nullptr;
}
181 
/// Send a message to the executor through the transport.
///
/// Setup messages are only received by this side, never sent, so OpC must not
/// be Setup (asserted). Returns the error, if any, produced by the
/// transport's sendMessage.
Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
                                   ExecutorAddr TagAddr,
                                   ArrayRef<char> ArgBytes) {
  assert(OpC != SimpleRemoteEPCOpcode::Setup &&
         "SimpleRemoteEPC sending Setup message? That's the wrong direction.");

  // Debug trace of the outgoing message. The asserts and llvm_unreachable
  // below are inside LLVM_DEBUG, so they only run when debug output is
  // enabled for this DEBUG_TYPE.
  LLVM_DEBUG({
    dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";
    switch (OpC) {
    case SimpleRemoteEPCOpcode::Hangup:
      dbgs() << "Hangup";
      assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
      break;
    case SimpleRemoteEPCOpcode::Result:
      dbgs() << "Result";
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
      break;
    case SimpleRemoteEPCOpcode::CallWrapper:
      dbgs() << "CallWrapper";
      break;
    default:
      llvm_unreachable("Invalid opcode");
    }
    dbgs() << ", seqno = " << SeqNo
           << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
           << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
           << " bytes\n";
  });
  auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
  // Log (but do not otherwise handle) transport send failures.
  LLVM_DEBUG({
    if (Err)
      dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
  });
  return Err;
}
218 
219 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
220                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
221   if (SeqNo != 0)
222     return make_error<StringError>("Setup packet SeqNo not zero",
223                                    inconvertibleErrorCode());
224 
225   if (TagAddr)
226     return make_error<StringError>("Setup packet TagAddr not zero",
227                                    inconvertibleErrorCode());
228 
229   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
230   auto I = PendingCallWrapperResults.find(0);
231   assert(PendingCallWrapperResults.size() == 1 &&
232          I != PendingCallWrapperResults.end() &&
233          "Setup message handler not connectly set up");
234   auto SetupMsgHandler = std::move(I->second);
235   PendingCallWrapperResults.erase(I);
236 
237   auto WFR =
238       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
239   SetupMsgHandler(std::move(WFR));
240   return Error::success();
241 }
242 
243 Error SimpleRemoteEPC::setup() {
244   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
245 
246   std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
247   auto EIF = EIP.get_future();
248 
249   // Prepare a handler for the setup packet.
250   PendingCallWrapperResults[0] =
251     RunInPlace()(
252       [&](shared::WrapperFunctionResult SetupMsgBytes) {
253         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
254           EIP.set_value(
255               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
256           return;
257         }
258         using SPSSerialize =
259             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
260         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
261         SimpleRemoteEPCExecutorInfo EI;
262         if (SPSSerialize::deserialize(IB, EI))
263           EIP.set_value(EI);
264         else
265           EIP.set_value(make_error<StringError>(
266               "Could not deserialize setup message", inconvertibleErrorCode()));
267       });
268 
269   // Start the transport.
270   if (auto Err = T->start())
271     return Err;
272 
273   // Wait for setup packet to arrive.
274   auto EI = EIF.get();
275   if (!EI) {
276     T->disconnect();
277     return EI.takeError();
278   }
279 
280   LLVM_DEBUG({
281     dbgs() << "SimpleRemoteEPC received setup message:\n"
282            << "  Triple: " << EI->TargetTriple << "\n"
283            << "  Page size: " << EI->PageSize << "\n"
284            << "  Bootstrap symbols:\n";
285     for (const auto &KV : EI->BootstrapSymbols)
286       dbgs() << "    " << KV.first() << ": "
287              << formatv("{0:x16}", KV.second.getValue()) << "\n";
288   });
289   TargetTriple = Triple(EI->TargetTriple);
290   PageSize = EI->PageSize;
291   BootstrapSymbols = std::move(EI->BootstrapSymbols);
292 
293   if (auto Err = getBootstrapSymbols(
294           {{JDI.JITDispatchContext, ExecutorSessionObjectName},
295            {JDI.JITDispatchFunction, DispatchFnName},
296            {RunAsMainAddr, rt::RunAsMainWrapperName}}))
297     return Err;
298 
299   if (auto DM =
300           EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))
301     DylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));
302   else
303     return DM.takeError();
304 
305   if (auto MemMgr = createMemoryManager()) {
306     OwnedMemMgr = std::move(*MemMgr);
307     this->MemMgr = OwnedMemMgr.get();
308   } else
309     return MemMgr.takeError();
310 
311   if (auto MemAccess = createMemoryAccess()) {
312     OwnedMemAccess = std::move(*MemAccess);
313     this->MemAccess = OwnedMemAccess.get();
314   } else
315     return MemAccess.takeError();
316 
317   return Error::success();
318 }
319 
320 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
321                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
322   IncomingWFRHandler SendResult;
323 
324   if (TagAddr)
325     return make_error<StringError>("Unexpected TagAddr in result message",
326                                    inconvertibleErrorCode());
327 
328   {
329     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
330     auto I = PendingCallWrapperResults.find(SeqNo);
331     if (I == PendingCallWrapperResults.end())
332       return make_error<StringError>("No call for sequence number " +
333                                          Twine(SeqNo),
334                                      inconvertibleErrorCode());
335     SendResult = std::move(I->second);
336     PendingCallWrapperResults.erase(I);
337     releaseSeqNo(SeqNo);
338   }
339 
340   auto WFR =
341       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
342   SendResult(std::move(WFR));
343   return Error::success();
344 }
345 
346 void SimpleRemoteEPC::handleCallWrapper(
347     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
348     SimpleRemoteEPCArgBytesVector ArgBytes) {
349   assert(ES && "No ExecutionSession attached");
350   ES->runJITDispatchHandler(
351       [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
352         if (auto Err = sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
353                                    ExecutorAddr(), {WFR.data(), WFR.size()}))
354           getExecutionSession().reportError(std::move(Err));
355       },
356       TagAddr.getValue(), ArgBytes);
357 }
358 
359 Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {
360   using namespace llvm::orc::shared;
361   auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
362   if (const char *ErrMsg = WFR.getOutOfBandError())
363     return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
364 
365   detail::SPSSerializableError Info;
366   SPSInputBuffer IB(WFR.data(), WFR.size());
367   if (!SPSArgList<SPSError>::deserialize(IB, Info))
368     return make_error<StringError>("Could not deserialize hangup info",
369                                    inconvertibleErrorCode());
370   return fromSPSSerializable(std::move(Info));
371 }
372 
373 } // end namespace orc
374 } // end namespace llvm
375