1 //===------- SimpleEPCServer.cpp - EPC over simple abstract channel -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.h"
10 
11 #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
12 #include "llvm/Support/FormatVariadic.h"
13 #include "llvm/Support/Host.h"
14 #include "llvm/Support/Process.h"
15 
16 #include "OrcRTBootstrap.h"
17 
18 #define DEBUG_TYPE "orc"
19 
20 using namespace llvm::orc::shared;
21 
22 namespace llvm {
23 namespace orc {
24 
25 ExecutorBootstrapService::~ExecutorBootstrapService() {}
26 
27 StringMap<ExecutorAddr> SimpleRemoteEPCServer::defaultBootstrapSymbols() {
28   StringMap<ExecutorAddr> DBS;
29   rt_bootstrap::addTo(DBS);
30   return DBS;
31 }
32 
/// Transport callback for a single incoming message from the remote peer.
///
/// Behaviour by opcode:
///  - Setup: rejected with an "Unexpected Setup opcode" error; this side
///    sends Setup (see sendSetupMessage) and does not expect to receive it.
///  - Hangup: returns EndSession.
///  - Result: passes the message to handleResult, which completes the pending
///    jit-dispatch call for \p SeqNo; any error it reports is returned.
///  - CallWrapper: passes the message to handleCallWrapper, which runs the
///    wrapper function at \p TagAddr on the dispatcher.
/// Opcodes numerically greater than LastOpC are rejected with
/// "Unexpected opcode". Result and CallWrapper return ContinueSession.
Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
SimpleRemoteEPCServer::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
                                     ExecutorAddr TagAddr,
                                     SimpleRemoteEPCArgBytesVector ArgBytes) {

  // Debug-only trace of the message header. The asserts below run only when
  // debug output for DEBUG_TYPE is enabled (and assertions are compiled in);
  // they sanity-check header fields expected to be zero for these opcodes.
  // Note: this runs before the range check below, so an out-of-range opcode
  // is logged without a name.
  LLVM_DEBUG({
    dbgs() << "SimpleRemoteEPCServer::handleMessage: opc = ";
    switch (OpC) {
    case SimpleRemoteEPCOpcode::Setup:
      dbgs() << "Setup";
      assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
      break;
    case SimpleRemoteEPCOpcode::Hangup:
      dbgs() << "Hangup";
      assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
      break;
    case SimpleRemoteEPCOpcode::Result:
      dbgs() << "Result";
      assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
      break;
    case SimpleRemoteEPCOpcode::CallWrapper:
      dbgs() << "CallWrapper";
      break;
    }
    dbgs() << ", seqno = " << SeqNo
           << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
           << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
           << " bytes\n";
  });

  // Reject opcode values outside the known range before dispatching on them.
  using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
  if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
    return make_error<StringError>("Unexpected opcode",
                                   inconvertibleErrorCode());

  // TODO: Clean detach message?
  switch (OpC) {
  case SimpleRemoteEPCOpcode::Setup:
    return make_error<StringError>("Unexpected Setup opcode",
                                   inconvertibleErrorCode());
  case SimpleRemoteEPCOpcode::Hangup:
    return SimpleRemoteEPCTransportClient::EndSession;
  case SimpleRemoteEPCOpcode::Result:
    if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
      return std::move(Err);
    break;
  case SimpleRemoteEPCOpcode::CallWrapper:
    // Asynchronous: the wrapper runs on the dispatcher, not on this thread.
    handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
    break;
  }
  return ContinueSession;
}
87 
88 Error SimpleRemoteEPCServer::waitForDisconnect() {
89   std::unique_lock<std::mutex> Lock(ServerStateMutex);
90   ShutdownCV.wait(Lock, [this]() { return RunState == ServerShutDown; });
91   return std::move(ShutdownErr);
92 }
93 
/// Transport callback run when the connection ends; \p Err is the reason
/// reported by the transport and is folded into the final shutdown error.
///
/// Shutdown sequence:
///  1. Under ServerStateMutex, take every pending jit-dispatch promise and
///     switch RunState to ServerShuttingDown, so doJITDispatch stops
///     accepting new calls.
///  2. Outside the lock, fail each taken call with an out-of-band
///     "disconnecting" error, waking the threads blocked in doJITDispatch.
///  3. Shut down the dispatcher, then each service from the back of Services
///     to the front, joining any errors into ShutdownErr.
///  4. Under the lock, join in \p Err, set ServerShutDown and wake all
///     waitForDisconnect callers.
void SimpleRemoteEPCServer::handleDisconnect(Error Err) {
  PendingJITDispatchResultsMap TmpPending;

  {
    // Swap the map out so the promises can be fulfilled without holding the
    // state mutex.
    std::lock_guard<std::mutex> Lock(ServerStateMutex);
    std::swap(TmpPending, PendingJITDispatchResults);
    RunState = ServerShuttingDown;
  }

  // Send out-of-band errors to any waiting threads.
  for (auto &KV : TmpPending)
    KV.second->set_value(
        shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));

  // Wait for dispatcher to clear.
  D->shutdown();

  // Shut down services.
  // ShutdownErr is updated here without the lock. Within this file it is only
  // read by waitForDisconnect, and only after ServerShutDown is published
  // under the lock below.
  while (!Services.empty()) {
    ShutdownErr =
      joinErrors(std::move(ShutdownErr), Services.back()->shutdown());
    Services.pop_back();
  }

  std::lock_guard<std::mutex> Lock(ServerStateMutex);
  ShutdownErr = joinErrors(std::move(ShutdownErr), std::move(Err));
  RunState = ServerShutDown;
  ShutdownCV.notify_all();
}
123 
124 Error SimpleRemoteEPCServer::sendMessage(SimpleRemoteEPCOpcode OpC,
125                                          uint64_t SeqNo, ExecutorAddr TagAddr,
126                                          ArrayRef<char> ArgBytes) {
127 
128   LLVM_DEBUG({
129     dbgs() << "SimpleRemoteEPCServer::sendMessage: opc = ";
130     switch (OpC) {
131     case SimpleRemoteEPCOpcode::Setup:
132       dbgs() << "Setup";
133       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
134       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
135       break;
136     case SimpleRemoteEPCOpcode::Hangup:
137       dbgs() << "Hangup";
138       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
139       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
140       break;
141     case SimpleRemoteEPCOpcode::Result:
142       dbgs() << "Result";
143       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
144       break;
145     case SimpleRemoteEPCOpcode::CallWrapper:
146       dbgs() << "CallWrapper";
147       break;
148     }
149     dbgs() << ", seqno = " << SeqNo
150            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
151            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
152            << " bytes\n";
153   });
154   auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
155   LLVM_DEBUG({
156     if (Err)
157       dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
158   });
159   return Err;
160 }
161 
162 Error SimpleRemoteEPCServer::sendSetupMessage(
163     StringMap<ExecutorAddr> BootstrapSymbols) {
164 
165   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
166 
167   std::vector<char> SetupPacket;
168   SimpleRemoteEPCExecutorInfo EI;
169   EI.TargetTriple = sys::getProcessTriple();
170   if (auto PageSize = sys::Process::getPageSize())
171     EI.PageSize = *PageSize;
172   else
173     return PageSize.takeError();
174   EI.BootstrapSymbols = std::move(BootstrapSymbols);
175 
176   assert(!EI.BootstrapSymbols.count(ExecutorSessionObjectName) &&
177          "Dispatch context name should not be set");
178   assert(!EI.BootstrapSymbols.count(DispatchFnName) &&
179          "Dispatch function name should not be set");
180   EI.BootstrapSymbols[ExecutorSessionObjectName] = ExecutorAddr::fromPtr(this);
181   EI.BootstrapSymbols[DispatchFnName] = ExecutorAddr::fromPtr(jitDispatchEntry);
182 
183   using SPSSerialize =
184       shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
185   auto SetupPacketBytes =
186       shared::WrapperFunctionResult::allocate(SPSSerialize::size(EI));
187   shared::SPSOutputBuffer OB(SetupPacketBytes.data(), SetupPacketBytes.size());
188   if (!SPSSerialize::serialize(OB, EI))
189     return make_error<StringError>("Could not send setup packet",
190                                    inconvertibleErrorCode());
191 
192   return sendMessage(SimpleRemoteEPCOpcode::Setup, 0, ExecutorAddr(),
193                      {SetupPacketBytes.data(), SetupPacketBytes.size()});
194 }
195 
196 Error SimpleRemoteEPCServer::handleResult(
197     uint64_t SeqNo, ExecutorAddr TagAddr,
198     SimpleRemoteEPCArgBytesVector ArgBytes) {
199   std::promise<shared::WrapperFunctionResult> *P = nullptr;
200   {
201     std::lock_guard<std::mutex> Lock(ServerStateMutex);
202     auto I = PendingJITDispatchResults.find(SeqNo);
203     if (I == PendingJITDispatchResults.end())
204       return make_error<StringError>("No call for sequence number " +
205                                          Twine(SeqNo),
206                                      inconvertibleErrorCode());
207     P = I->second;
208     PendingJITDispatchResults.erase(I);
209     releaseSeqNo(SeqNo);
210   }
211   auto R = shared::WrapperFunctionResult::allocate(ArgBytes.size());
212   memcpy(R.data(), ArgBytes.data(), ArgBytes.size());
213   P->set_value(std::move(R));
214   return Error::success();
215 }
216 
217 void SimpleRemoteEPCServer::handleCallWrapper(
218     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
219     SimpleRemoteEPCArgBytesVector ArgBytes) {
220   D->dispatch([this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
221     using WrapperFnTy =
222         shared::detail::CWrapperFunctionResult (*)(const char *, size_t);
223     auto *Fn = TagAddr.toPtr<WrapperFnTy>();
224     shared::WrapperFunctionResult ResultBytes(
225         Fn(ArgBytes.data(), ArgBytes.size()));
226     if (auto Err = sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
227                                ExecutorAddr(),
228                                {ResultBytes.data(), ResultBytes.size()}))
229       ReportError(std::move(Err));
230   });
231 }
232 
233 shared::WrapperFunctionResult
234 SimpleRemoteEPCServer::doJITDispatch(const void *FnTag, const char *ArgData,
235                                      size_t ArgSize) {
236   uint64_t SeqNo;
237   std::promise<shared::WrapperFunctionResult> ResultP;
238   auto ResultF = ResultP.get_future();
239   {
240     std::lock_guard<std::mutex> Lock(ServerStateMutex);
241     if (RunState != ServerRunning)
242       return shared::WrapperFunctionResult::createOutOfBandError(
243           "jit_dispatch not available (EPC server shut down)");
244 
245     SeqNo = getNextSeqNo();
246     assert(!PendingJITDispatchResults.count(SeqNo) && "SeqNo already in use");
247     PendingJITDispatchResults[SeqNo] = &ResultP;
248   }
249 
250   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
251                              ExecutorAddr::fromPtr(FnTag), {ArgData, ArgSize}))
252     ReportError(std::move(Err));
253 
254   return ResultF.get();
255 }
256 
257 shared::detail::CWrapperFunctionResult
258 SimpleRemoteEPCServer::jitDispatchEntry(void *DispatchCtx, const void *FnTag,
259                                         const char *ArgData, size_t ArgSize) {
260   return reinterpret_cast<SimpleRemoteEPCServer *>(DispatchCtx)
261       ->doJITDispatch(FnTag, ArgData, ArgSize)
262       .release();
263 }
264 
265 } // end namespace orc
266 } // end namespace llvm
267