1 //===----------------- Server.cpp - Server Implementation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Offloading gRPC server for remote host. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <cmath> 14 #include <future> 15 16 #include "Server.h" 17 #include "omptarget.h" 18 #include "openmp.grpc.pb.h" 19 #include "openmp.pb.h" 20 21 using grpc::WriteOptions; 22 23 extern std::promise<void> ShutdownPromise; 24 25 Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request, 26 I32 *Reply) { 27 SERVER_DBG("Shutting down the server") 28 29 Reply->set_number(0); 30 ShutdownPromise.set_value(); 31 return Status::OK; 32 } 33 34 Status 35 RemoteOffloadImpl::RegisterLib(ServerContext *Context, 36 const TargetBinaryDescription *Description, 37 I32 *Reply) { 38 auto Desc = std::make_unique<__tgt_bin_desc>(); 39 40 unloadTargetBinaryDescription(Description, Desc.get(), 41 HostToRemoteDeviceImage); 42 PM->RTLs.RegisterLib(Desc.get()); 43 44 if (Descriptions.find((void *)Description->bin_ptr()) != Descriptions.end()) 45 freeTargetBinaryDescription( 46 Descriptions[(void *)Description->bin_ptr()].get()); 47 else 48 Descriptions[(void *)Description->bin_ptr()] = std::move(Desc); 49 50 SERVER_DBG("Registered library") 51 Reply->set_number(0); 52 return Status::OK; 53 } 54 55 Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context, 56 const Pointer *Request, I32 *Reply) { 57 if (Descriptions.find((void *)Request->number()) == Descriptions.end()) { 58 Reply->set_number(1); 59 return Status::OK; 60 } 61 62 PM->RTLs.UnregisterLib(Descriptions[(void *)Request->number()].get()); 63 freeTargetBinaryDescription(Descriptions[(void *)Request->number()].get()); 64 Descriptions.erase((void *)Request->number()); 65 66 SERVER_DBG("Unregistered library") 67 Reply->set_number(0); 68 return Status::OK; 69 } 70 71 Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context, 72 const TargetDeviceImagePtr *DeviceImage, 73 I32 *IsValid) { 74 __tgt_device_image *Image = 75 HostToRemoteDeviceImage[(void *)DeviceImage->image_ptr()]; 76 77 IsValid->set_number(0); 78 79 for (auto &RTL : PM->RTLs.AllRTLs) 80 if (auto Ret = RTL.is_valid_binary(Image)) { 81 IsValid->set_number(Ret); 82 break; 83 } 84 85 SERVER_DBG("Checked if binary (%p) is valid", 86 (void *)(DeviceImage->image_ptr())) 87 return Status::OK; 88 } 89 90 Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context, 91 const Null *Null, 92 I32 *NumberOfDevices) { 93 std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs); 94 95 int32_t Devices = 0; 96 PM->RTLsMtx.lock(); 97 for (auto &RTL : PM->RTLs.AllRTLs) 98 Devices += RTL.NumberOfDevices; 99 PM->RTLsMtx.unlock(); 100 101 NumberOfDevices->set_number(Devices); 102 103 SERVER_DBG("Got number of devices") 104 return Status::OK; 105 } 106 107 Status RemoteOffloadImpl::InitDevice(ServerContext *Context, 108 const I32 *DeviceNum, I32 *Reply) { 109 Reply->set_number(PM->Devices[DeviceNum->number()].RTL->init_device( 110 mapHostRTLDeviceId(DeviceNum->number()))); 111 112 SERVER_DBG("Initialized device %d", DeviceNum->number()) 113 return Status::OK; 114 } 115 116 Status RemoteOffloadImpl::InitRequires(ServerContext *Context, 117 const I64 *RequiresFlag, I32 *Reply) { 118 for (auto &Device : PM->Devices) 119 if (Device.RTL->init_requires) 120 Device.RTL->init_requires(RequiresFlag->number()); 121 Reply->set_number(RequiresFlag->number()); 122 123 SERVER_DBG("Initialized requires for devices") 124 return Status::OK; 125 } 126 127 Status RemoteOffloadImpl::LoadBinary(ServerContext *Context, 128 const Binary *Binary, TargetTable *Reply) { 129 __tgt_device_image *Image = 130 HostToRemoteDeviceImage[(void *)Binary->image_ptr()]; 131 132 Table = PM->Devices[Binary->device_id()].RTL->load_binary( 133 mapHostRTLDeviceId(Binary->device_id()), Image); 134 if (Table) 135 loadTargetTable(Table, *Reply, Image); 136 137 SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(), 138 Binary->device_id()) 139 return Status::OK; 140 } 141 142 Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context, 143 const DevicePair *Request, 144 I32 *Reply) { 145 Reply->set_number(-1); 146 if (PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())] 147 .RTL->is_data_exchangable) 148 Reply->set_number(PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())] 149 .RTL->is_data_exchangable(Request->src_dev_id(), 150 Request->dst_dev_id())); 151 152 SERVER_DBG("Checked if data exchangeable between device %d and device %d", 153 Request->src_dev_id(), Request->dst_dev_id()) 154 return Status::OK; 155 } 156 157 Status RemoteOffloadImpl::DataAlloc(ServerContext *Context, 158 const AllocData *Request, Pointer *Reply) { 159 uint64_t TgtPtr = (uint64_t)PM->Devices[Request->device_id()].RTL->data_alloc( 160 mapHostRTLDeviceId(Request->device_id()), Request->size(), 161 (void *)Request->hst_ptr(), TARGET_ALLOC_DEFAULT); 162 Reply->set_number(TgtPtr); 163 164 SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr)) 165 166 return Status::OK; 167 } 168 169 Status RemoteOffloadImpl::DataSubmit(ServerContext *Context, 170 ServerReader<SubmitData> *Reader, 171 I32 *Reply) { 172 SubmitData Request; 173 uint8_t *HostCopy = nullptr; 174 while (Reader->Read(&Request)) { 175 if (Request.start() == 0 && Request.size() == Request.data().size()) { 176 Reader->SendInitialMetadata(); 177 178 Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit( 179 mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(), 180 (void *)Request.data().data(), Request.data().size())); 181 182 SERVER_DBG("Submitted %lu bytes async to (%p) on device %d", 183 Request.data().size(), (void *)Request.tgt_ptr(), 184 Request.device_id()) 185 186 return Status::OK; 187 } 188 if (!HostCopy) { 189 HostCopy = new uint8_t[Request.size()]; 190 Reader->SendInitialMetadata(); 191 } 192 193 memcpy((void *)((char *)HostCopy + Request.start()), Request.data().data(), 194 Request.data().size()); 195 } 196 197 Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit( 198 mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(), 199 HostCopy, Request.size())); 200 201 delete[] HostCopy; 202 203 SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request.data().size(), 204 (void *)Request.tgt_ptr(), Request.device_id()) 205 206 return Status::OK; 207 } 208 209 Status RemoteOffloadImpl::DataRetrieve(ServerContext *Context, 210 const RetrieveData *Request, 211 ServerWriter<Data> *Writer) { 212 auto HstPtr = std::make_unique<char[]>(Request->size()); 213 214 auto Ret = PM->Devices[Request->device_id()].RTL->data_retrieve( 215 mapHostRTLDeviceId(Request->device_id()), HstPtr.get(), 216 (void *)Request->tgt_ptr(), Request->size()); 217 218 if (Arena->SpaceAllocated() >= MaxSize) 219 Arena->Reset(); 220 221 if (Request->size() > BlockSize) { 222 uint64_t Start = 0, End = BlockSize; 223 for (auto I = 0; I < ceil((float)Request->size() / BlockSize); I++) { 224 auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get()); 225 226 Reply->set_start(Start); 227 Reply->set_size(Request->size()); 228 Reply->set_data((char *)HstPtr.get() + Start, End - Start); 229 Reply->set_ret(Ret); 230 231 if (!Writer->Write(*Reply)) { 232 CLIENT_DBG("Broken stream when submitting data") 233 } 234 235 SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start, 236 End, Request->size(), (void *)Request->tgt_ptr(), 237 mapHostRTLDeviceId(Request->device_id())) 238 239 Start += BlockSize; 240 End += BlockSize; 241 if (End >= Request->size()) 242 End = Request->size(); 243 } 244 } else { 245 auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get()); 246 247 Reply->set_start(0); 248 Reply->set_size(Request->size()); 249 Reply->set_data((char *)HstPtr.get(), Request->size()); 250 Reply->set_ret(Ret); 251 252 SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(), 253 (void *)Request->tgt_ptr(), 254 mapHostRTLDeviceId(Request->device_id())) 255 256 Writer->WriteLast(*Reply, WriteOptions()); 257 } 258 259 return Status::OK; 260 } 261 262 Status RemoteOffloadImpl::DataExchange(ServerContext *Context, 263 const ExchangeData *Request, 264 I32 *Reply) { 265 if (PM->Devices[Request->src_dev_id()].RTL->data_exchange) { 266 int32_t Ret = PM->Devices[Request->src_dev_id()].RTL->data_exchange( 267 mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), 268 mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(), 269 Request->size()); 270 Reply->set_number(Ret); 271 } else 272 Reply->set_number(-1); 273 274 SERVER_DBG( 275 "Exchanged data asynchronously from device %d (%p) to device %d (%p) of " 276 "size %lu", 277 mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), 278 mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(), 279 Request->size()) 280 return Status::OK; 281 } 282 283 Status RemoteOffloadImpl::DataDelete(ServerContext *Context, 284 const DeleteData *Request, I32 *Reply) { 285 auto Ret = PM->Devices[Request->device_id()].RTL->data_delete( 286 mapHostRTLDeviceId(Request->device_id()), (void *)Request->tgt_ptr()); 287 Reply->set_number(Ret); 288 289 SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(), 290 mapHostRTLDeviceId(Request->device_id())) 291 return Status::OK; 292 } 293 294 Status RemoteOffloadImpl::RunTargetRegion(ServerContext *Context, 295 const TargetRegion *Request, 296 I32 *Reply) { 297 std::vector<uint8_t> TgtArgs(Request->arg_num()); 298 for (auto I = 0; I < Request->arg_num(); I++) 299 TgtArgs[I] = (uint64_t)Request->tgt_args()[I]; 300 301 std::vector<ptrdiff_t> TgtOffsets(Request->arg_num()); 302 const auto *TgtOffsetItr = Request->tgt_offsets().begin(); 303 for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++) 304 TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr; 305 306 void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr; 307 308 int32_t Ret = PM->Devices[Request->device_id()].RTL->run_region( 309 mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr, 310 (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num()); 311 312 Reply->set_number(Ret); 313 314 SERVER_DBG("Ran TargetRegion on device %d with %d args", 315 mapHostRTLDeviceId(Request->device_id()), Request->arg_num()) 316 return Status::OK; 317 } 318 319 Status RemoteOffloadImpl::RunTargetTeamRegion(ServerContext *Context, 320 const TargetTeamRegion *Request, 321 I32 *Reply) { 322 std::vector<uint64_t> TgtArgs(Request->arg_num()); 323 for (auto I = 0; I < Request->arg_num(); I++) 324 TgtArgs[I] = (uint64_t)Request->tgt_args()[I]; 325 326 std::vector<ptrdiff_t> TgtOffsets(Request->arg_num()); 327 const auto *TgtOffsetItr = Request->tgt_offsets().begin(); 328 for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++) 329 TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr; 330 331 void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr; 332 333 int32_t Ret = PM->Devices[Request->device_id()].RTL->run_team_region( 334 mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr, 335 (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(), 336 Request->team_num(), Request->thread_limit(), Request->loop_tripcount()); 337 338 Reply->set_number(Ret); 339 340 SERVER_DBG("Ran TargetTeamRegion on device %d with %d args", 341 mapHostRTLDeviceId(Request->device_id()), Request->arg_num()) 342 return Status::OK; 343 } 344 345 int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID) { 346 for (auto &RTL : PM->RTLs.UsedRTLs) { 347 if (RTLDeviceID - RTL->NumberOfDevices >= 0) 348 RTLDeviceID -= RTL->NumberOfDevices; 349 else 350 break; 351 } 352 return RTLDeviceID; 353 } 354