1 //===----------------- Server.cpp - Server Implementation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Offloading gRPC server for remote host. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <cmath> 14 #include <future> 15 16 #include "Server.h" 17 #include "omptarget.h" 18 #include "openmp.grpc.pb.h" 19 #include "openmp.pb.h" 20 21 using grpc::WriteOptions; 22 23 extern std::promise<void> ShutdownPromise; 24 25 Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request, 26 I32 *Reply) { 27 SERVER_DBG("Shutting down the server") 28 29 Reply->set_number(0); 30 ShutdownPromise.set_value(); 31 return Status::OK; 32 } 33 34 Status 35 RemoteOffloadImpl::RegisterLib(ServerContext *Context, 36 const TargetBinaryDescription *Description, 37 I32 *Reply) { 38 auto Desc = std::make_unique<__tgt_bin_desc>(); 39 40 unloadTargetBinaryDescription(Description, Desc.get(), 41 HostToRemoteDeviceImage); 42 PM->RTLs.RegisterLib(Desc.get()); 43 44 if (Descriptions.find((void *)Description->bin_ptr()) != Descriptions.end()) 45 freeTargetBinaryDescription( 46 Descriptions[(void *)Description->bin_ptr()].get()); 47 else 48 Descriptions[(void *)Description->bin_ptr()] = std::move(Desc); 49 50 SERVER_DBG("Registered library") 51 Reply->set_number(0); 52 return Status::OK; 53 } 54 55 Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context, 56 const Pointer *Request, I32 *Reply) { 57 if (Descriptions.find((void *)Request->number()) == Descriptions.end()) { 58 Reply->set_number(1); 59 return Status::OK; 60 } 61 62 PM->RTLs.UnregisterLib(Descriptions[(void *)Request->number()].get()); 63 freeTargetBinaryDescription(Descriptions[(void *)Request->number()].get()); 64 Descriptions.erase((void *)Request->number()); 65 66 SERVER_DBG("Unregistered library") 67 Reply->set_number(0); 68 return Status::OK; 69 } 70 71 Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context, 72 const TargetDeviceImagePtr *DeviceImage, 73 I32 *IsValid) { 74 __tgt_device_image *Image = 75 HostToRemoteDeviceImage[(void *)DeviceImage->image_ptr()]; 76 77 IsValid->set_number(0); 78 79 for (auto &RTL : PM->RTLs.AllRTLs) 80 if (auto Ret = RTL.is_valid_binary(Image)) { 81 IsValid->set_number(Ret); 82 break; 83 } 84 85 SERVER_DBG("Checked if binary (%p) is valid", 86 (void *)(DeviceImage->image_ptr())) 87 return Status::OK; 88 } 89 90 Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context, 91 const Null *Null, 92 I32 *NumberOfDevices) { 93 std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs); 94 95 int32_t Devices = 0; 96 PM->RTLsMtx.lock(); 97 for (auto &RTL : PM->RTLs.AllRTLs) 98 Devices += RTL.NumberOfDevices; 99 PM->RTLsMtx.unlock(); 100 101 NumberOfDevices->set_number(Devices); 102 103 SERVER_DBG("Got number of devices") 104 return Status::OK; 105 } 106 107 Status RemoteOffloadImpl::InitDevice(ServerContext *Context, 108 const I32 *DeviceNum, I32 *Reply) { 109 Reply->set_number(PM->Devices[DeviceNum->number()]->RTL->init_device( 110 mapHostRTLDeviceId(DeviceNum->number()))); 111 112 SERVER_DBG("Initialized device %d", DeviceNum->number()) 113 return Status::OK; 114 } 115 116 Status RemoteOffloadImpl::InitRequires(ServerContext *Context, 117 const I64 *RequiresFlag, I32 *Reply) { 118 for (auto &Device : PM->Devices) 119 if (Device->RTL->init_requires) 120 Device->RTL->init_requires(RequiresFlag->number()); 121 Reply->set_number(RequiresFlag->number()); 122 123 SERVER_DBG("Initialized requires for devices") 124 return Status::OK; 125 } 126 127 Status RemoteOffloadImpl::LoadBinary(ServerContext *Context, 128 const Binary *Binary, TargetTable *Reply) { 129 __tgt_device_image *Image = 130 HostToRemoteDeviceImage[(void *)Binary->image_ptr()]; 131 132 Table = PM->Devices[Binary->device_id()]->RTL->load_binary( 133 mapHostRTLDeviceId(Binary->device_id()), Image); 134 if (Table) 135 loadTargetTable(Table, *Reply, Image); 136 137 SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(), 138 Binary->device_id()) 139 return Status::OK; 140 } 141 142 Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context, 143 const DevicePair *Request, 144 I32 *Reply) { 145 Reply->set_number(-1); 146 if (PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())] 147 ->RTL->is_data_exchangable) 148 Reply->set_number(PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())] 149 ->RTL->is_data_exchangable(Request->src_dev_id(), 150 Request->dst_dev_id())); 151 152 SERVER_DBG("Checked if data exchangeable between device %d and device %d", 153 Request->src_dev_id(), Request->dst_dev_id()) 154 return Status::OK; 155 } 156 157 Status RemoteOffloadImpl::DataAlloc(ServerContext *Context, 158 const AllocData *Request, Pointer *Reply) { 159 uint64_t TgtPtr = 160 (uint64_t)PM->Devices[Request->device_id()]->RTL->data_alloc( 161 mapHostRTLDeviceId(Request->device_id()), Request->size(), 162 (void *)Request->hst_ptr(), TARGET_ALLOC_DEFAULT); 163 Reply->set_number(TgtPtr); 164 165 SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr)) 166 167 return Status::OK; 168 } 169 170 Status RemoteOffloadImpl::DataSubmit(ServerContext *Context, 171 ServerReader<SubmitData> *Reader, 172 I32 *Reply) { 173 SubmitData Request; 174 uint8_t *HostCopy = nullptr; 175 while (Reader->Read(&Request)) { 176 if (Request.start() == 0 && Request.size() == Request.data().size()) { 177 Reader->SendInitialMetadata(); 178 179 Reply->set_number(PM->Devices[Request.device_id()]->RTL->data_submit( 180 mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(), 181 (void *)Request.data().data(), Request.data().size())); 182 183 SERVER_DBG("Submitted %lu bytes async to (%p) on device %d", 184 Request.data().size(), (void *)Request.tgt_ptr(), 185 Request.device_id()) 186 187 return Status::OK; 188 } 189 if (!HostCopy) { 190 HostCopy = new uint8_t[Request.size()]; 191 Reader->SendInitialMetadata(); 192 } 193 194 memcpy((void *)((char *)HostCopy + Request.start()), Request.data().data(), 195 Request.data().size()); 196 } 197 198 Reply->set_number(PM->Devices[Request.device_id()]->RTL->data_submit( 199 mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(), 200 HostCopy, Request.size())); 201 202 delete[] HostCopy; 203 204 SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request.data().size(), 205 (void *)Request.tgt_ptr(), Request.device_id()) 206 207 return Status::OK; 208 } 209 210 Status RemoteOffloadImpl::DataRetrieve(ServerContext *Context, 211 const RetrieveData *Request, 212 ServerWriter<Data> *Writer) { 213 auto HstPtr = std::make_unique<char[]>(Request->size()); 214 215 auto Ret = PM->Devices[Request->device_id()]->RTL->data_retrieve( 216 mapHostRTLDeviceId(Request->device_id()), HstPtr.get(), 217 (void *)Request->tgt_ptr(), Request->size()); 218 219 if (Arena->SpaceAllocated() >= MaxSize) 220 Arena->Reset(); 221 222 if (Request->size() > BlockSize) { 223 uint64_t Start = 0, End = BlockSize; 224 for (auto I = 0; I < ceil((float)Request->size() / BlockSize); I++) { 225 auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get()); 226 227 Reply->set_start(Start); 228 Reply->set_size(Request->size()); 229 Reply->set_data((char *)HstPtr.get() + Start, End - Start); 230 Reply->set_ret(Ret); 231 232 if (!Writer->Write(*Reply)) { 233 CLIENT_DBG("Broken stream when submitting data") 234 } 235 236 SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start, 237 End, Request->size(), (void *)Request->tgt_ptr(), 238 mapHostRTLDeviceId(Request->device_id())) 239 240 Start += BlockSize; 241 End += BlockSize; 242 if (End >= Request->size()) 243 End = Request->size(); 244 } 245 } else { 246 auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get()); 247 248 Reply->set_start(0); 249 Reply->set_size(Request->size()); 250 Reply->set_data((char *)HstPtr.get(), Request->size()); 251 Reply->set_ret(Ret); 252 253 SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(), 254 (void *)Request->tgt_ptr(), 255 mapHostRTLDeviceId(Request->device_id())) 256 257 Writer->WriteLast(*Reply, WriteOptions()); 258 } 259 260 return Status::OK; 261 } 262 263 Status RemoteOffloadImpl::DataExchange(ServerContext *Context, 264 const ExchangeData *Request, 265 I32 *Reply) { 266 if (PM->Devices[Request->src_dev_id()]->RTL->data_exchange) { 267 int32_t Ret = PM->Devices[Request->src_dev_id()]->RTL->data_exchange( 268 mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), 269 mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(), 270 Request->size()); 271 Reply->set_number(Ret); 272 } else 273 Reply->set_number(-1); 274 275 SERVER_DBG( 276 "Exchanged data asynchronously from device %d (%p) to device %d (%p) of " 277 "size %lu", 278 mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), 279 mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(), 280 Request->size()) 281 return Status::OK; 282 } 283 284 Status RemoteOffloadImpl::DataDelete(ServerContext *Context, 285 const DeleteData *Request, I32 *Reply) { 286 auto Ret = PM->Devices[Request->device_id()]->RTL->data_delete( 287 mapHostRTLDeviceId(Request->device_id()), (void *)Request->tgt_ptr()); 288 Reply->set_number(Ret); 289 290 SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(), 291 mapHostRTLDeviceId(Request->device_id())) 292 return Status::OK; 293 } 294 295 Status RemoteOffloadImpl::RunTargetRegion(ServerContext *Context, 296 const TargetRegion *Request, 297 I32 *Reply) { 298 std::vector<uint8_t> TgtArgs(Request->arg_num()); 299 for (auto I = 0; I < Request->arg_num(); I++) 300 TgtArgs[I] = (uint64_t)Request->tgt_args()[I]; 301 302 std::vector<ptrdiff_t> TgtOffsets(Request->arg_num()); 303 const auto *TgtOffsetItr = Request->tgt_offsets().begin(); 304 for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++) 305 TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr; 306 307 void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr; 308 309 int32_t Ret = PM->Devices[Request->device_id()]->RTL->run_region( 310 mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr, 311 (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num()); 312 313 Reply->set_number(Ret); 314 315 SERVER_DBG("Ran TargetRegion on device %d with %d args", 316 mapHostRTLDeviceId(Request->device_id()), Request->arg_num()) 317 return Status::OK; 318 } 319 320 Status RemoteOffloadImpl::RunTargetTeamRegion(ServerContext *Context, 321 const TargetTeamRegion *Request, 322 I32 *Reply) { 323 std::vector<uint64_t> TgtArgs(Request->arg_num()); 324 for (auto I = 0; I < Request->arg_num(); I++) 325 TgtArgs[I] = (uint64_t)Request->tgt_args()[I]; 326 327 std::vector<ptrdiff_t> TgtOffsets(Request->arg_num()); 328 const auto *TgtOffsetItr = Request->tgt_offsets().begin(); 329 for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++) 330 TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr; 331 332 void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr; 333 334 int32_t Ret = PM->Devices[Request->device_id()]->RTL->run_team_region( 335 mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr, 336 (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(), 337 Request->team_num(), Request->thread_limit(), Request->loop_tripcount()); 338 339 Reply->set_number(Ret); 340 341 SERVER_DBG("Ran TargetTeamRegion on device %d with %d args", 342 mapHostRTLDeviceId(Request->device_id()), Request->arg_num()) 343 return Status::OK; 344 } 345 346 int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID) { 347 for (auto &RTL : PM->RTLs.UsedRTLs) { 348 if (RTLDeviceID - RTL->NumberOfDevices >= 0) 349 RTLDeviceID -= RTL->NumberOfDevices; 350 else 351 break; 352 } 353 return RTLDeviceID; 354 } 355