1 //===----------------- Server.cpp - Server Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Offloading gRPC server for remote host.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <future>
#include <memory>
#include <mutex>
#include <vector>

#include "Server.h"
#include "omptarget.h"
#include "openmp.grpc.pb.h"
#include "openmp.pb.h"
20 
21 using grpc::WriteOptions;
22 
23 extern std::promise<void> ShutdownPromise;
24 
Shutdown(ServerContext * Context,const Null * Request,I32 * Reply)25 Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request,
26                                    I32 *Reply) {
27   SERVER_DBG("Shutting down the server")
28 
29   Reply->set_number(0);
30   ShutdownPromise.set_value();
31   return Status::OK;
32 }
33 
34 Status
RegisterLib(ServerContext * Context,const TargetBinaryDescription * Description,I32 * Reply)35 RemoteOffloadImpl::RegisterLib(ServerContext *Context,
36                                const TargetBinaryDescription *Description,
37                                I32 *Reply) {
38   auto Desc = std::make_unique<__tgt_bin_desc>();
39 
40   unloadTargetBinaryDescription(Description, Desc.get(),
41                                 HostToRemoteDeviceImage);
42   PM->RTLs.RegisterLib(Desc.get());
43 
44   if (Descriptions.find((void *)Description->bin_ptr()) != Descriptions.end())
45     freeTargetBinaryDescription(
46         Descriptions[(void *)Description->bin_ptr()].get());
47   else
48     Descriptions[(void *)Description->bin_ptr()] = std::move(Desc);
49 
50   SERVER_DBG("Registered library")
51   Reply->set_number(0);
52   return Status::OK;
53 }
54 
UnregisterLib(ServerContext * Context,const Pointer * Request,I32 * Reply)55 Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context,
56                                         const Pointer *Request, I32 *Reply) {
57   if (Descriptions.find((void *)Request->number()) == Descriptions.end()) {
58     Reply->set_number(1);
59     return Status::OK;
60   }
61 
62   PM->RTLs.UnregisterLib(Descriptions[(void *)Request->number()].get());
63   freeTargetBinaryDescription(Descriptions[(void *)Request->number()].get());
64   Descriptions.erase((void *)Request->number());
65 
66   SERVER_DBG("Unregistered library")
67   Reply->set_number(0);
68   return Status::OK;
69 }
70 
IsValidBinary(ServerContext * Context,const TargetDeviceImagePtr * DeviceImage,I32 * IsValid)71 Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context,
72                                         const TargetDeviceImagePtr *DeviceImage,
73                                         I32 *IsValid) {
74   __tgt_device_image *Image =
75       HostToRemoteDeviceImage[(void *)DeviceImage->image_ptr()];
76 
77   IsValid->set_number(0);
78 
79   for (auto &RTL : PM->RTLs.AllRTLs)
80     if (auto Ret = RTL.is_valid_binary(Image)) {
81       IsValid->set_number(Ret);
82       break;
83     }
84 
85   SERVER_DBG("Checked if binary (%p) is valid",
86              (void *)(DeviceImage->image_ptr()))
87   return Status::OK;
88 }
89 
GetNumberOfDevices(ServerContext * Context,const Null * Null,I32 * NumberOfDevices)90 Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context,
91                                              const Null *Null,
92                                              I32 *NumberOfDevices) {
93   std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs);
94 
95   int32_t Devices = 0;
96   PM->RTLsMtx.lock();
97   for (auto &RTL : PM->RTLs.AllRTLs)
98     Devices += RTL.NumberOfDevices;
99   PM->RTLsMtx.unlock();
100 
101   NumberOfDevices->set_number(Devices);
102 
103   SERVER_DBG("Got number of devices")
104   return Status::OK;
105 }
106 
InitDevice(ServerContext * Context,const I32 * DeviceNum,I32 * Reply)107 Status RemoteOffloadImpl::InitDevice(ServerContext *Context,
108                                      const I32 *DeviceNum, I32 *Reply) {
109   Reply->set_number(PM->Devices[DeviceNum->number()]->RTL->init_device(
110       mapHostRTLDeviceId(DeviceNum->number())));
111 
112   SERVER_DBG("Initialized device %d", DeviceNum->number())
113   return Status::OK;
114 }
115 
InitRequires(ServerContext * Context,const I64 * RequiresFlag,I32 * Reply)116 Status RemoteOffloadImpl::InitRequires(ServerContext *Context,
117                                        const I64 *RequiresFlag, I32 *Reply) {
118   for (auto &Device : PM->Devices)
119     if (Device->RTL->init_requires)
120       Device->RTL->init_requires(RequiresFlag->number());
121   Reply->set_number(RequiresFlag->number());
122 
123   SERVER_DBG("Initialized requires for devices")
124   return Status::OK;
125 }
126 
LoadBinary(ServerContext * Context,const Binary * Binary,TargetTable * Reply)127 Status RemoteOffloadImpl::LoadBinary(ServerContext *Context,
128                                      const Binary *Binary, TargetTable *Reply) {
129   __tgt_device_image *Image =
130       HostToRemoteDeviceImage[(void *)Binary->image_ptr()];
131 
132   Table = PM->Devices[Binary->device_id()]->RTL->load_binary(
133       mapHostRTLDeviceId(Binary->device_id()), Image);
134   if (Table)
135     loadTargetTable(Table, *Reply, Image);
136 
137   SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(),
138              Binary->device_id())
139   return Status::OK;
140 }
141 
IsDataExchangeable(ServerContext * Context,const DevicePair * Request,I32 * Reply)142 Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context,
143                                              const DevicePair *Request,
144                                              I32 *Reply) {
145   Reply->set_number(-1);
146   if (PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())]
147           ->RTL->is_data_exchangable)
148     Reply->set_number(PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())]
149                           ->RTL->is_data_exchangable(Request->src_dev_id(),
150                                                      Request->dst_dev_id()));
151 
152   SERVER_DBG("Checked if data exchangeable between device %d and device %d",
153              Request->src_dev_id(), Request->dst_dev_id())
154   return Status::OK;
155 }
156 
DataAlloc(ServerContext * Context,const AllocData * Request,Pointer * Reply)157 Status RemoteOffloadImpl::DataAlloc(ServerContext *Context,
158                                     const AllocData *Request, Pointer *Reply) {
159   uint64_t TgtPtr =
160       (uint64_t)PM->Devices[Request->device_id()]->RTL->data_alloc(
161           mapHostRTLDeviceId(Request->device_id()), Request->size(),
162           (void *)Request->hst_ptr(), TARGET_ALLOC_DEFAULT);
163   Reply->set_number(TgtPtr);
164 
165   SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr))
166 
167   return Status::OK;
168 }
169 
DataSubmit(ServerContext * Context,ServerReader<SubmitData> * Reader,I32 * Reply)170 Status RemoteOffloadImpl::DataSubmit(ServerContext *Context,
171                                      ServerReader<SubmitData> *Reader,
172                                      I32 *Reply) {
173   SubmitData Request;
174   uint8_t *HostCopy = nullptr;
175   while (Reader->Read(&Request)) {
176     if (Request.start() == 0 && Request.size() == Request.data().size()) {
177       Reader->SendInitialMetadata();
178 
179       Reply->set_number(PM->Devices[Request.device_id()]->RTL->data_submit(
180           mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
181           (void *)Request.data().data(), Request.data().size()));
182 
183       SERVER_DBG("Submitted %lu bytes async to (%p) on device %d",
184                  Request.data().size(), (void *)Request.tgt_ptr(),
185                  Request.device_id())
186 
187       return Status::OK;
188     }
189     if (!HostCopy) {
190       HostCopy = new uint8_t[Request.size()];
191       Reader->SendInitialMetadata();
192     }
193 
194     memcpy((void *)((char *)HostCopy + Request.start()), Request.data().data(),
195            Request.data().size());
196   }
197 
198   Reply->set_number(PM->Devices[Request.device_id()]->RTL->data_submit(
199       mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
200       HostCopy, Request.size()));
201 
202   delete[] HostCopy;
203 
204   SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request.data().size(),
205              (void *)Request.tgt_ptr(), Request.device_id())
206 
207   return Status::OK;
208 }
209 
DataRetrieve(ServerContext * Context,const RetrieveData * Request,ServerWriter<Data> * Writer)210 Status RemoteOffloadImpl::DataRetrieve(ServerContext *Context,
211                                        const RetrieveData *Request,
212                                        ServerWriter<Data> *Writer) {
213   auto HstPtr = std::make_unique<char[]>(Request->size());
214 
215   auto Ret = PM->Devices[Request->device_id()]->RTL->data_retrieve(
216       mapHostRTLDeviceId(Request->device_id()), HstPtr.get(),
217       (void *)Request->tgt_ptr(), Request->size());
218 
219   if (Arena->SpaceAllocated() >= MaxSize)
220     Arena->Reset();
221 
222   if (Request->size() > BlockSize) {
223     uint64_t Start = 0, End = BlockSize;
224     for (auto I = 0; I < ceil((float)Request->size() / BlockSize); I++) {
225       auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
226 
227       Reply->set_start(Start);
228       Reply->set_size(Request->size());
229       Reply->set_data((char *)HstPtr.get() + Start, End - Start);
230       Reply->set_ret(Ret);
231 
232       if (!Writer->Write(*Reply)) {
233         CLIENT_DBG("Broken stream when submitting data")
234       }
235 
236       SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start,
237                  End, Request->size(), (void *)Request->tgt_ptr(),
238                  mapHostRTLDeviceId(Request->device_id()))
239 
240       Start += BlockSize;
241       End += BlockSize;
242       if (End >= Request->size())
243         End = Request->size();
244     }
245   } else {
246     auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
247 
248     Reply->set_start(0);
249     Reply->set_size(Request->size());
250     Reply->set_data((char *)HstPtr.get(), Request->size());
251     Reply->set_ret(Ret);
252 
253     SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(),
254                (void *)Request->tgt_ptr(),
255                mapHostRTLDeviceId(Request->device_id()))
256 
257     Writer->WriteLast(*Reply, WriteOptions());
258   }
259 
260   return Status::OK;
261 }
262 
DataExchange(ServerContext * Context,const ExchangeData * Request,I32 * Reply)263 Status RemoteOffloadImpl::DataExchange(ServerContext *Context,
264                                        const ExchangeData *Request,
265                                        I32 *Reply) {
266   if (PM->Devices[Request->src_dev_id()]->RTL->data_exchange) {
267     int32_t Ret = PM->Devices[Request->src_dev_id()]->RTL->data_exchange(
268         mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
269         mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
270         Request->size());
271     Reply->set_number(Ret);
272   } else
273     Reply->set_number(-1);
274 
275   SERVER_DBG(
276       "Exchanged data asynchronously from device %d (%p) to device %d (%p) of "
277       "size %lu",
278       mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
279       mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
280       Request->size())
281   return Status::OK;
282 }
283 
DataDelete(ServerContext * Context,const DeleteData * Request,I32 * Reply)284 Status RemoteOffloadImpl::DataDelete(ServerContext *Context,
285                                      const DeleteData *Request, I32 *Reply) {
286   auto Ret = PM->Devices[Request->device_id()]->RTL->data_delete(
287       mapHostRTLDeviceId(Request->device_id()), (void *)Request->tgt_ptr());
288   Reply->set_number(Ret);
289 
290   SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(),
291              mapHostRTLDeviceId(Request->device_id()))
292   return Status::OK;
293 }
294 
RunTargetRegion(ServerContext * Context,const TargetRegion * Request,I32 * Reply)295 Status RemoteOffloadImpl::RunTargetRegion(ServerContext *Context,
296                                           const TargetRegion *Request,
297                                           I32 *Reply) {
298   std::vector<uint8_t> TgtArgs(Request->arg_num());
299   for (auto I = 0; I < Request->arg_num(); I++)
300     TgtArgs[I] = (uint64_t)Request->tgt_args()[I];
301 
302   std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
303   const auto *TgtOffsetItr = Request->tgt_offsets().begin();
304   for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
305     TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;
306 
307   void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;
308 
309   int32_t Ret = PM->Devices[Request->device_id()]->RTL->run_region(
310       mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
311       (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num());
312 
313   Reply->set_number(Ret);
314 
315   SERVER_DBG("Ran TargetRegion on device %d with %d args",
316              mapHostRTLDeviceId(Request->device_id()), Request->arg_num())
317   return Status::OK;
318 }
319 
RunTargetTeamRegion(ServerContext * Context,const TargetTeamRegion * Request,I32 * Reply)320 Status RemoteOffloadImpl::RunTargetTeamRegion(ServerContext *Context,
321                                               const TargetTeamRegion *Request,
322                                               I32 *Reply) {
323   std::vector<uint64_t> TgtArgs(Request->arg_num());
324   for (auto I = 0; I < Request->arg_num(); I++)
325     TgtArgs[I] = (uint64_t)Request->tgt_args()[I];
326 
327   std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
328   const auto *TgtOffsetItr = Request->tgt_offsets().begin();
329   for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
330     TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;
331 
332   void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;
333 
334   int32_t Ret = PM->Devices[Request->device_id()]->RTL->run_team_region(
335       mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
336       (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(),
337       Request->team_num(), Request->thread_limit(), Request->loop_tripcount());
338 
339   Reply->set_number(Ret);
340 
341   SERVER_DBG("Ran TargetTeamRegion on device %d with %d args",
342              mapHostRTLDeviceId(Request->device_id()), Request->arg_num())
343   return Status::OK;
344 }
345 
mapHostRTLDeviceId(int32_t RTLDeviceID)346 int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID) {
347   for (auto &RTL : PM->RTLs.UsedRTLs) {
348     if (RTLDeviceID - RTL->NumberOfDevices >= 0)
349       RTLDeviceID -= RTL->NumberOfDevices;
350     else
351       break;
352   }
353   return RTLDeviceID;
354 }
355