1 //===----------------- Server.cpp - Server Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Offloading gRPC server for remote host.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <cmath>
14 #include <future>
15 
16 #include "Server.h"
17 #include "omptarget.h"
18 #include "openmp.grpc.pb.h"
19 #include "openmp.pb.h"
20 
21 using grpc::WriteOptions;
22 
23 extern std::promise<void> ShutdownPromise;
24 
25 Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request,
26                                    I32 *Reply) {
27   SERVER_DBG("Shutting down the server");
28 
29   Reply->set_number(0);
30   ShutdownPromise.set_value();
31   return Status::OK;
32 }
33 
34 Status
35 RemoteOffloadImpl::RegisterLib(ServerContext *Context,
36                                const TargetBinaryDescription *Description,
37                                I32 *Reply) {
38   SERVER_DBG("Registering library");
39 
40   auto Desc = std::make_unique<__tgt_bin_desc>();
41 
42   unloadTargetBinaryDescription(Description, Desc.get(),
43                                 HostToRemoteDeviceImage);
44   PM->RTLs.RegisterLib(Desc.get());
45 
46   if (Descriptions.find((void *)Description->bin_ptr()) != Descriptions.end())
47     freeTargetBinaryDescription(
48         Descriptions[(void *)Description->bin_ptr()].get());
49   else
50     Descriptions[(void *)Description->bin_ptr()] = std::move(Desc);
51 
52   SERVER_DBG("Registered library");
53   Reply->set_number(0);
54   return Status::OK;
55 }
56 
57 Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context,
58                                         const Pointer *Request, I32 *Reply) {
59   SERVER_DBG("Unregistering library");
60 
61   if (Descriptions.find((void *)Request->number()) == Descriptions.end()) {
62     Reply->set_number(1);
63     return Status::OK;
64   }
65 
66   PM->RTLs.UnregisterLib(Descriptions[(void *)Request->number()].get());
67   freeTargetBinaryDescription(Descriptions[(void *)Request->number()].get());
68   Descriptions.erase((void *)Request->number());
69 
70   SERVER_DBG("Unregistered library");
71   Reply->set_number(0);
72   return Status::OK;
73 }
74 
75 Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context,
76                                         const TargetDeviceImagePtr *DeviceImage,
77                                         I32 *IsValid) {
78   SERVER_DBG("Checking if binary (%p) is valid",
79              (void *)(DeviceImage->image_ptr()));
80 
81   __tgt_device_image *Image =
82       HostToRemoteDeviceImage[(void *)DeviceImage->image_ptr()];
83 
84   IsValid->set_number(0);
85 
86   for (auto &RTL : PM->RTLs.AllRTLs)
87     if (auto Ret = RTL.is_valid_binary(Image)) {
88       IsValid->set_number(Ret);
89       break;
90     }
91 
92   SERVER_DBG("Checked if binary (%p) is valid",
93              (void *)(DeviceImage->image_ptr()));
94   return Status::OK;
95 }
96 
97 Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context,
98                                              const Null *Null,
99                                              I32 *NumberOfDevices) {
100   SERVER_DBG("Getting number of devices");
101   std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs);
102 
103   int32_t Devices = 0;
104   PM->RTLsMtx.lock();
105   for (auto &RTL : PM->RTLs.AllRTLs)
106     Devices += RTL.NumberOfDevices;
107   PM->RTLsMtx.unlock();
108 
109   NumberOfDevices->set_number(Devices);
110 
111   SERVER_DBG("Got number of devices");
112   return Status::OK;
113 }
114 
115 Status RemoteOffloadImpl::InitDevice(ServerContext *Context,
116                                      const I32 *DeviceNum, I32 *Reply) {
117   SERVER_DBG("Initializing device %d", DeviceNum->number());
118 
119   Reply->set_number(PM->Devices[DeviceNum->number()].RTL->init_device(
120       mapHostRTLDeviceId(DeviceNum->number())));
121 
122   SERVER_DBG("Initialized device %d", DeviceNum->number());
123   return Status::OK;
124 }
125 
126 Status RemoteOffloadImpl::InitRequires(ServerContext *Context,
127                                        const I64 *RequiresFlag, I32 *Reply) {
128   SERVER_DBG("Initializing requires for devices");
129 
130   for (auto &Device : PM->Devices)
131     if (Device.RTL->init_requires)
132       Device.RTL->init_requires(RequiresFlag->number());
133   Reply->set_number(RequiresFlag->number());
134 
135   SERVER_DBG("Initialized requires for devices");
136   return Status::OK;
137 }
138 
139 Status RemoteOffloadImpl::LoadBinary(ServerContext *Context,
140                                      const Binary *Binary, TargetTable *Reply) {
141   SERVER_DBG("Loading binary (%p) to device %d", (void *)Binary->image_ptr(),
142              Binary->device_id());
143 
144   __tgt_device_image *Image =
145       HostToRemoteDeviceImage[(void *)Binary->image_ptr()];
146 
147   Table = PM->Devices[Binary->device_id()].RTL->load_binary(
148       mapHostRTLDeviceId(Binary->device_id()), Image);
149   if (Table)
150     loadTargetTable(Table, *Reply, Image);
151 
152   SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(),
153              Binary->device_id());
154   return Status::OK;
155 }
156 
157 Status RemoteOffloadImpl::Synchronize(ServerContext *Context,
158                                       const SynchronizeDevice *Info,
159                                       I32 *Reply) {
160   SERVER_DBG("Synchronizing device %d (probably won't work)",
161              Info->device_id());
162 
163   void *AsyncInfo = (void *)Info->queue_ptr();
164   Reply->set_number(0);
165   if (PM->Devices[Info->device_id()].RTL->synchronize)
166     Reply->set_number(PM->Devices[Info->device_id()].synchronize(
167         (__tgt_async_info *)AsyncInfo));
168 
169   SERVER_DBG("Synchronized device %d", Info->device_id());
170   return Status::OK;
171 }
172 
173 Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context,
174                                              const DevicePair *Request,
175                                              I32 *Reply) {
176   SERVER_DBG("Checking if data exchangable between device %d and device %d",
177              Request->src_dev_id(), Request->dst_dev_id());
178 
179   Reply->set_number(-1);
180   if (PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())]
181           .RTL->is_data_exchangable)
182     Reply->set_number(PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())]
183                           .RTL->is_data_exchangable(Request->src_dev_id(),
184                                                     Request->dst_dev_id()));
185 
186   SERVER_DBG("Checked if data exchangable between device %d and device %d",
187              Request->src_dev_id(), Request->dst_dev_id());
188   return Status::OK;
189 }
190 
191 Status RemoteOffloadImpl::DataAlloc(ServerContext *Context,
192                                     const AllocData *Request, Pointer *Reply) {
193   SERVER_DBG("Allocating %ld bytes on sevice %d", Request->size(),
194              Request->device_id());
195 
196   uint64_t TgtPtr = (uint64_t)PM->Devices[Request->device_id()].RTL->data_alloc(
197       mapHostRTLDeviceId(Request->device_id()), Request->size(),
198       (void *)Request->hst_ptr());
199   Reply->set_number(TgtPtr);
200 
201   SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr));
202 
203   return Status::OK;
204 }
205 
206 Status RemoteOffloadImpl::DataSubmitAsync(ServerContext *Context,
207                                           ServerReader<SubmitDataAsync> *Reader,
208                                           I32 *Reply) {
209   SubmitDataAsync Request;
210   uint8_t *HostCopy = nullptr;
211   while (Reader->Read(&Request)) {
212     if (Request.start() == 0 && Request.size() == Request.data().size()) {
213       SERVER_DBG("Submitting %lu bytes async to (%p) on device %d",
214                  Request.data().size(), (void *)Request.tgt_ptr(),
215                  Request.device_id());
216 
217       SERVER_DBG("  Host Pointer Info: %p, %p", (void *)Request.hst_ptr(),
218                  static_cast<const void *>(Request.data().data()));
219 
220       Reader->SendInitialMetadata();
221 
222       Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit(
223           mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
224           (void *)Request.data().data(), Request.data().size()));
225 
226       SERVER_DBG("Submitted %lu bytes async to (%p) on device %d",
227                  Request.data().size(), (void *)Request.tgt_ptr(),
228                  Request.device_id());
229 
230       return Status::OK;
231     }
232     if (!HostCopy) {
233       HostCopy = new uint8_t[Request.size()];
234       Reader->SendInitialMetadata();
235     }
236 
237     SERVER_DBG("Submitting %lu-%lu/%lu bytes async to (%p) on device %d",
238                Request.start(), Request.start() + Request.data().size(),
239                Request.size(), (void *)Request.tgt_ptr(), Request.device_id());
240 
241     memcpy((void *)((char *)HostCopy + Request.start()), Request.data().data(),
242            Request.data().size());
243   }
244   SERVER_DBG("  Host Pointer Info: %p, %p", (void *)Request.hst_ptr(),
245              static_cast<const void *>(Request.data().data()));
246 
247   Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit(
248       mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
249       HostCopy, Request.size()));
250 
251   delete[] HostCopy;
252 
253   SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request.data().size(),
254              (void *)Request.tgt_ptr(), Request.device_id());
255 
256   return Status::OK;
257 }
258 
259 Status RemoteOffloadImpl::DataRetrieveAsync(ServerContext *Context,
260                                             const RetrieveDataAsync *Request,
261                                             ServerWriter<Data> *Writer) {
262   auto HstPtr = std::make_unique<char[]>(Request->size());
263   auto Ret = PM->Devices[Request->device_id()].RTL->data_retrieve(
264       mapHostRTLDeviceId(Request->device_id()), HstPtr.get(),
265       (void *)Request->tgt_ptr(), Request->size());
266 
267   if (Arena->SpaceAllocated() >= MaxSize)
268     Arena->Reset();
269 
270   if (Request->size() > BlockSize) {
271     uint64_t Start = 0, End = BlockSize;
272     for (auto I = 0; I < ceil((float)Request->size() / BlockSize); I++) {
273       auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
274 
275       Reply->set_start(Start);
276       Reply->set_size(Request->size());
277       Reply->set_data((char *)HstPtr.get() + Start, End - Start);
278       Reply->set_ret(Ret);
279 
280       SERVER_DBG("Retrieving %lu-%lu/%lu bytes from (%p) on device %d", Start,
281                  End, Request->size(), (void *)Request->tgt_ptr(),
282                  mapHostRTLDeviceId(Request->device_id()));
283 
284       if (!Writer->Write(*Reply)) {
285         CLIENT_DBG("Broken stream when submitting data");
286       }
287 
288       SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start,
289                  End, Request->size(), (void *)Request->tgt_ptr(),
290                  mapHostRTLDeviceId(Request->device_id()));
291 
292       Start += BlockSize;
293       End += BlockSize;
294       if (End >= Request->size())
295         End = Request->size();
296     }
297   } else {
298     auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
299 
300     SERVER_DBG("Retrieve %lu bytes from (%p) on device %d", Request->size(),
301                (void *)Request->tgt_ptr(),
302                mapHostRTLDeviceId(Request->device_id()));
303 
304     Reply->set_start(0);
305     Reply->set_size(Request->size());
306     Reply->set_data((char *)HstPtr.get(), Request->size());
307     Reply->set_ret(Ret);
308 
309     SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(),
310                (void *)Request->tgt_ptr(),
311                mapHostRTLDeviceId(Request->device_id()));
312 
313     Writer->WriteLast(*Reply, WriteOptions());
314   }
315 
316   return Status::OK;
317 }
318 
319 Status RemoteOffloadImpl::DataExchangeAsync(ServerContext *Context,
320                                             const ExchangeDataAsync *Request,
321                                             I32 *Reply) {
322   SERVER_DBG(
323       "Exchanging data asynchronously from device %d (%p) to device %d (%p) of "
324       "size %lu",
325       mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
326       mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
327       Request->size());
328 
329   if (PM->Devices[Request->src_dev_id()].RTL->data_exchange) {
330     int32_t Ret = PM->Devices[Request->src_dev_id()].RTL->data_exchange(
331         mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
332         mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
333         Request->size());
334     Reply->set_number(Ret);
335   } else
336     Reply->set_number(-1);
337 
338   SERVER_DBG(
339       "Exchanged data asynchronously from device %d (%p) to device %d (%p) of "
340       "size %lu",
341       mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
342       mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
343       Request->size());
344   return Status::OK;
345 }
346 
347 Status RemoteOffloadImpl::DataDelete(ServerContext *Context,
348                                      const DeleteData *Request, I32 *Reply) {
349   SERVER_DBG("Deleting data from (%p) on device %d", (void *)Request->tgt_ptr(),
350              mapHostRTLDeviceId(Request->device_id()));
351 
352   auto Ret = PM->Devices[Request->device_id()].RTL->data_delete(
353       mapHostRTLDeviceId(Request->device_id()), (void *)Request->tgt_ptr());
354   Reply->set_number(Ret);
355 
356   SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(),
357              mapHostRTLDeviceId(Request->device_id()));
358   return Status::OK;
359 }
360 
361 Status RemoteOffloadImpl::RunTargetRegionAsync(ServerContext *Context,
362                                                const TargetRegionAsync *Request,
363                                                I32 *Reply) {
364   SERVER_DBG("Running TargetRegionAsync on device %d with %d args",
365              mapHostRTLDeviceId(Request->device_id()), Request->arg_num());
366 
367   std::vector<uint8_t> TgtArgs(Request->arg_num());
368   for (auto I = 0; I < Request->arg_num(); I++)
369     TgtArgs[I] = (uint64_t)Request->tgt_args()[I];
370 
371   std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
372   const auto *TgtOffsetItr = Request->tgt_offsets().begin();
373   for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
374     TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;
375 
376   void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;
377 
378   int32_t Ret = PM->Devices[Request->device_id()].RTL->run_region(
379       mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
380       (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num());
381 
382   Reply->set_number(Ret);
383 
384   SERVER_DBG("Ran TargetRegionAsync on device %d with %d args",
385              mapHostRTLDeviceId(Request->device_id()), Request->arg_num());
386   return Status::OK;
387 }
388 
389 Status RemoteOffloadImpl::RunTargetTeamRegionAsync(
390     ServerContext *Context, const TargetTeamRegionAsync *Request, I32 *Reply) {
391   SERVER_DBG("Running TargetTeamRegionAsync on device %d with %d args",
392              mapHostRTLDeviceId(Request->device_id()), Request->arg_num());
393 
394   std::vector<uint64_t> TgtArgs(Request->arg_num());
395   for (auto I = 0; I < Request->arg_num(); I++)
396     TgtArgs[I] = (uint64_t)Request->tgt_args()[I];
397 
398   std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
399   const auto *TgtOffsetItr = Request->tgt_offsets().begin();
400   for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
401     TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;
402 
403   void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;
404   int32_t Ret = PM->Devices[Request->device_id()].RTL->run_team_region(
405       mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
406       (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(),
407       Request->team_num(), Request->thread_limit(), Request->loop_tripcount());
408 
409   Reply->set_number(Ret);
410 
411   SERVER_DBG("Ran TargetTeamRegionAsync on device %d with %d args",
412              mapHostRTLDeviceId(Request->device_id()), Request->arg_num());
413   return Status::OK;
414 }
415 
416 int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID) {
417   for (auto &RTL : PM->RTLs.UsedRTLs) {
418     if (RTLDeviceID - RTL->NumberOfDevices >= 0)
419       RTLDeviceID -= RTL->NumberOfDevices;
420     else
421       break;
422   }
423   return RTLDeviceID;
424 }
425