1 //===----------------- Client.cpp - Client Implementation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // gRPC (Client) for the remote plugin. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <cmath> 14 15 #include "Client.h" 16 #include "omptarget.h" 17 #include "openmp.pb.h" 18 19 using namespace std::chrono; 20 21 using grpc::ClientContext; 22 using grpc::ClientReader; 23 using grpc::ClientWriter; 24 using grpc::Status; 25 26 template <typename Fn1, typename Fn2, typename TReturn> 27 auto RemoteOffloadClient::remoteCall(Fn1 Preprocess, Fn2 Postprocess, 28 TReturn ErrorValue, bool Timeout) { 29 ArenaAllocatorLock->lock(); 30 if (Arena->SpaceAllocated() >= MaxSize) 31 Arena->Reset(); 32 ArenaAllocatorLock->unlock(); 33 34 ClientContext Context; 35 if (Timeout) { 36 auto Deadline = 37 std::chrono::system_clock::now() + std::chrono::seconds(Timeout); 38 Context.set_deadline(Deadline); 39 } 40 41 Status RPCStatus; 42 auto Reply = Preprocess(RPCStatus, Context); 43 44 // TODO: Error handle more appropriately 45 if (!RPCStatus.ok()) { 46 CLIENT_DBG("%s", RPCStatus.error_message().c_str()); 47 } else { 48 return Postprocess(Reply); 49 } 50 51 CLIENT_DBG("Failed"); 52 return ErrorValue; 53 } 54 55 int32_t RemoteOffloadClient::shutdown(void) { 56 ClientContext Context; 57 Null Request; 58 I32 Reply; 59 CLIENT_DBG("Shutting down server."); 60 auto Status = Stub->Shutdown(&Context, Request, &Reply); 61 if (Status.ok()) 62 return Reply.number(); 63 return 1; 64 } 65 66 int32_t RemoteOffloadClient::registerLib(__tgt_bin_desc *Desc) { 67 return remoteCall( 68 /* Preprocess */ 69 [&](auto &RPCStatus, auto &Context) { 70 auto *Request = protobuf::Arena::CreateMessage<TargetBinaryDescription>( 71 Arena.get()); 72 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 73 loadTargetBinaryDescription(Desc, *Request); 74 Request->set_bin_ptr((uint64_t)Desc); 75 76 CLIENT_DBG("Registering library"); 77 RPCStatus = Stub->RegisterLib(&Context, *Request, Reply); 78 return Reply; 79 }, 80 /* Postprocess */ 81 [&](const auto &Reply) { 82 if (Reply->number() == 0) { 83 CLIENT_DBG("Registered library"); 84 return 0; 85 } 86 return 1; 87 }, 88 /* Error Value */ 1); 89 } 90 91 int32_t RemoteOffloadClient::unregisterLib(__tgt_bin_desc *Desc) { 92 return remoteCall( 93 /* Preprocess */ 94 [&](auto &RPCStatus, auto &Context) { 95 auto *Request = protobuf::Arena::CreateMessage<Pointer>(Arena.get()); 96 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 97 98 Request->set_number((uint64_t)Desc); 99 100 CLIENT_DBG("Unregistering library"); 101 RPCStatus = Stub->UnregisterLib(&Context, *Request, Reply); 102 return Reply; 103 }, 104 /* Postprocess */ 105 [&](const auto &Reply) { 106 if (Reply->number() == 0) { 107 CLIENT_DBG("Unregistered library"); 108 return 0; 109 } 110 CLIENT_DBG("Failed to unregister library"); 111 return 1; 112 }, 113 /* Error Value */ 1); 114 } 115 116 int32_t RemoteOffloadClient::isValidBinary(__tgt_device_image *Image) { 117 return remoteCall( 118 /* Preprocess */ 119 [&](auto &RPCStatus, auto &Context) { 120 auto *Request = 121 protobuf::Arena::CreateMessage<TargetDeviceImagePtr>(Arena.get()); 122 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 123 124 Request->set_image_ptr((uint64_t)Image->ImageStart); 125 126 auto *EntryItr = Image->EntriesBegin; 127 while (EntryItr != Image->EntriesEnd) 128 Request->add_entry_ptrs((uint64_t)EntryItr++); 129 130 CLIENT_DBG("Validating binary"); 131 RPCStatus = Stub->IsValidBinary(&Context, *Request, Reply); 132 return Reply; 133 }, 134 /* Postprocess */ 135 [&](const auto &Reply) { 136 if (Reply->number()) { 137 CLIENT_DBG("Validated binary"); 138 } else { 139 CLIENT_DBG("Could not validate binary"); 140 } 141 return Reply->number(); 142 }, 143 /* Error Value */ 0); 144 } 145 146 int32_t RemoteOffloadClient::getNumberOfDevices() { 147 return remoteCall( 148 /* Preprocess */ 149 [&](Status &RPCStatus, ClientContext &Context) { 150 auto *Request = protobuf::Arena::CreateMessage<Null>(Arena.get()); 151 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 152 153 CLIENT_DBG("Getting number of devices"); 154 RPCStatus = Stub->GetNumberOfDevices(&Context, *Request, Reply); 155 156 return Reply; 157 }, 158 /* Postprocess */ 159 [&](const auto &Reply) { 160 if (Reply->number()) { 161 CLIENT_DBG("Found %d devices", Reply->number()); 162 } else { 163 CLIENT_DBG("Could not get the number of devices"); 164 } 165 return Reply->number(); 166 }, 167 /*Error Value*/ -1); 168 } 169 170 int32_t RemoteOffloadClient::initDevice(int32_t DeviceId) { 171 return remoteCall( 172 /* Preprocess */ 173 [&](auto &RPCStatus, auto &Context) { 174 auto *Request = protobuf::Arena::CreateMessage<I32>(Arena.get()); 175 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 176 177 Request->set_number(DeviceId); 178 179 CLIENT_DBG("Initializing device %d", DeviceId); 180 RPCStatus = Stub->InitDevice(&Context, *Request, Reply); 181 182 return Reply; 183 }, 184 /* Postprocess */ 185 [&](const auto &Reply) { 186 if (!Reply->number()) { 187 CLIENT_DBG("Initialized device %d", DeviceId); 188 } else { 189 CLIENT_DBG("Could not initialize device %d", DeviceId); 190 } 191 return Reply->number(); 192 }, 193 /* Error Value */ -1); 194 } 195 196 int32_t RemoteOffloadClient::initRequires(int64_t RequiresFlags) { 197 return remoteCall( 198 /* Preprocess */ 199 [&](auto &RPCStatus, auto &Context) { 200 auto *Request = protobuf::Arena::CreateMessage<I64>(Arena.get()); 201 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 202 Request->set_number(RequiresFlags); 203 CLIENT_DBG("Initializing requires"); 204 RPCStatus = Stub->InitRequires(&Context, *Request, Reply); 205 return Reply; 206 }, 207 /* Postprocess */ 208 [&](const auto &Reply) { 209 if (Reply->number()) { 210 CLIENT_DBG("Initialized requires"); 211 } else { 212 CLIENT_DBG("Could not initialize requires"); 213 } 214 return Reply->number(); 215 }, 216 /* Error Value */ -1); 217 } 218 219 __tgt_target_table *RemoteOffloadClient::loadBinary(int32_t DeviceId, 220 __tgt_device_image *Image) { 221 return remoteCall( 222 /* Preprocess */ 223 [&](auto &RPCStatus, auto &Context) { 224 auto *ImageMessage = 225 protobuf::Arena::CreateMessage<Binary>(Arena.get()); 226 auto *Reply = protobuf::Arena::CreateMessage<TargetTable>(Arena.get()); 227 ImageMessage->set_image_ptr((uint64_t)Image->ImageStart); 228 ImageMessage->set_device_id(DeviceId); 229 230 CLIENT_DBG("Loading Image %p to device %d", Image, DeviceId); 231 RPCStatus = Stub->LoadBinary(&Context, *ImageMessage, Reply); 232 return Reply; 233 }, 234 /* Postprocess */ 235 [&](auto &Reply) { 236 if (Reply->entries_size() == 0) { 237 CLIENT_DBG("Could not load image %p onto device %d", Image, DeviceId); 238 return (__tgt_target_table *)nullptr; 239 } 240 DevicesToTables[DeviceId] = std::make_unique<__tgt_target_table>(); 241 unloadTargetTable(*Reply, DevicesToTables[DeviceId].get(), 242 RemoteEntries[DeviceId]); 243 244 CLIENT_DBG("Loaded Image %p to device %d with %d entries", Image, 245 DeviceId, Reply->entries_size()); 246 247 return DevicesToTables[DeviceId].get(); 248 }, 249 /* Error Value */ (__tgt_target_table *)nullptr, 250 /* Timeout */ false); 251 } 252 253 int64_t RemoteOffloadClient::synchronize(int32_t DeviceId, 254 __tgt_async_info *AsyncInfo) { 255 return remoteCall( 256 /* Preprocess */ 257 [&](auto &RPCStatus, auto &Context) { 258 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 259 auto *Info = 260 protobuf::Arena::CreateMessage<SynchronizeDevice>(Arena.get()); 261 262 Info->set_device_id(DeviceId); 263 Info->set_queue_ptr((uint64_t)AsyncInfo); 264 265 CLIENT_DBG("Synchronizing device %d", DeviceId); 266 RPCStatus = Stub->Synchronize(&Context, *Info, Reply); 267 return Reply; 268 }, 269 /* Postprocess */ 270 [&](auto &Reply) { 271 if (Reply->number()) { 272 CLIENT_DBG("Synchronized device %d", DeviceId); 273 } else { 274 CLIENT_DBG("Could not synchronize device %d", DeviceId); 275 } 276 return Reply->number(); 277 }, 278 /* Error Value */ -1); 279 } 280 281 int32_t RemoteOffloadClient::isDataExchangeable(int32_t SrcDevId, 282 int32_t DstDevId) { 283 return remoteCall( 284 /* Preprocess */ 285 [&](auto &RPCStatus, auto &Context) { 286 auto *Request = protobuf::Arena::CreateMessage<DevicePair>(Arena.get()); 287 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 288 289 Request->set_src_dev_id(SrcDevId); 290 Request->set_dst_dev_id(DstDevId); 291 292 CLIENT_DBG("Asking if data is exchangeable between %d, %d", SrcDevId, 293 DstDevId); 294 RPCStatus = Stub->IsDataExchangeable(&Context, *Request, Reply); 295 return Reply; 296 }, 297 /* Postprocess */ 298 [&](auto &Reply) { 299 if (Reply->number()) { 300 CLIENT_DBG("Data is exchangeable between %d, %d", SrcDevId, DstDevId); 301 } else { 302 CLIENT_DBG("Data is not exchangeable between %d, %d", SrcDevId, 303 DstDevId); 304 } 305 return Reply->number(); 306 }, 307 /* Error Value */ -1); 308 } 309 310 void *RemoteOffloadClient::dataAlloc(int32_t DeviceId, int64_t Size, 311 void *HstPtr) { 312 return remoteCall( 313 /* Preprocess */ 314 [&](auto &RPCStatus, auto &Context) { 315 auto *Reply = protobuf::Arena::CreateMessage<Pointer>(Arena.get()); 316 auto *Request = protobuf::Arena::CreateMessage<AllocData>(Arena.get()); 317 318 Request->set_device_id(DeviceId); 319 Request->set_size(Size); 320 Request->set_hst_ptr((uint64_t)HstPtr); 321 322 CLIENT_DBG("Allocating %ld bytes on device %d", Size, DeviceId); 323 RPCStatus = Stub->DataAlloc(&Context, *Request, Reply); 324 return Reply; 325 }, 326 /* Postprocess */ 327 [&](auto &Reply) { 328 if (Reply->number()) { 329 CLIENT_DBG("Allocated %ld bytes on device %d at %p", Size, DeviceId, 330 (void *)Reply->number()); 331 } else { 332 CLIENT_DBG("Could not allocate %ld bytes on device %d at %p", Size, 333 DeviceId, (void *)Reply->number()); 334 } 335 return (void *)Reply->number(); 336 }, 337 /* Error Value */ (void *)nullptr); 338 } 339 340 int32_t RemoteOffloadClient::dataSubmitAsync(int32_t DeviceId, void *TgtPtr, 341 void *HstPtr, int64_t Size, 342 __tgt_async_info *AsyncInfo) { 343 344 return remoteCall( 345 /* Preprocess */ 346 [&](auto &RPCStatus, auto &Context) { 347 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 348 std::unique_ptr<ClientWriter<SubmitDataAsync>> Writer( 349 Stub->DataSubmitAsync(&Context, Reply)); 350 351 if (Size > BlockSize) { 352 int64_t Start = 0, End = BlockSize; 353 for (auto I = 0; I < ceil((float)Size / BlockSize); I++) { 354 auto *Request = 355 protobuf::Arena::CreateMessage<SubmitDataAsync>(Arena.get()); 356 357 Request->set_device_id(DeviceId); 358 Request->set_data((char *)HstPtr + Start, End - Start); 359 Request->set_hst_ptr((uint64_t)HstPtr); 360 Request->set_tgt_ptr((uint64_t)TgtPtr); 361 Request->set_start(Start); 362 Request->set_size(Size); 363 Request->set_queue_ptr((uint64_t)AsyncInfo); 364 365 CLIENT_DBG("Submitting %ld-%ld/%ld bytes async on device %d at %p", 366 Start, End, Size, DeviceId, TgtPtr) 367 368 if (!Writer->Write(*Request)) { 369 CLIENT_DBG("Broken stream when submitting data"); 370 Reply->set_number(0); 371 return Reply; 372 } 373 374 Start += BlockSize; 375 End += BlockSize; 376 if (End >= Size) 377 End = Size; 378 } 379 } else { 380 auto *Request = 381 protobuf::Arena::CreateMessage<SubmitDataAsync>(Arena.get()); 382 383 Request->set_device_id(DeviceId); 384 Request->set_data(HstPtr, Size); 385 Request->set_hst_ptr((uint64_t)HstPtr); 386 Request->set_tgt_ptr((uint64_t)TgtPtr); 387 Request->set_start(0); 388 Request->set_size(Size); 389 390 CLIENT_DBG("Submitting %ld bytes async on device %d at %p", Size, 391 DeviceId, TgtPtr) 392 if (!Writer->Write(*Request)) { 393 CLIENT_DBG("Broken stream when submitting data"); 394 Reply->set_number(0); 395 return Reply; 396 } 397 } 398 399 Writer->WritesDone(); 400 RPCStatus = Writer->Finish(); 401 402 return Reply; 403 }, 404 /* Postprocess */ 405 [&](auto &Reply) { 406 if (!Reply->number()) { 407 CLIENT_DBG("Async submitted %ld bytes on device %d at %p", Size, 408 DeviceId, TgtPtr) 409 } else { 410 CLIENT_DBG("Could not async submit %ld bytes on device %d at %p", 411 Size, DeviceId, TgtPtr) 412 } 413 return Reply->number(); 414 }, 415 /* Error Value */ -1, 416 /* Timeout */ false); 417 } 418 419 int32_t RemoteOffloadClient::dataRetrieveAsync(int32_t DeviceId, void *HstPtr, 420 void *TgtPtr, int64_t Size, 421 __tgt_async_info *AsyncInfo) { 422 return remoteCall( 423 /* Preprocess */ 424 [&](auto &RPCStatus, auto &Context) { 425 auto *Request = 426 protobuf::Arena::CreateMessage<RetrieveDataAsync>(Arena.get()); 427 428 Request->set_device_id(DeviceId); 429 Request->set_size(Size); 430 Request->set_hst_ptr((int64_t)HstPtr); 431 Request->set_tgt_ptr((int64_t)TgtPtr); 432 Request->set_queue_ptr((uint64_t)AsyncInfo); 433 434 auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get()); 435 std::unique_ptr<ClientReader<Data>> Reader( 436 Stub->DataRetrieveAsync(&Context, *Request)); 437 Reader->WaitForInitialMetadata(); 438 while (Reader->Read(Reply)) { 439 if (Reply->ret()) { 440 CLIENT_DBG("Could not async retrieve %ld bytes on device %d at %p " 441 "for %p", 442 Size, DeviceId, TgtPtr, HstPtr) 443 return Reply; 444 } 445 446 if (Reply->start() == 0 && Reply->size() == Reply->data().size()) { 447 CLIENT_DBG("Async retrieving %ld bytes on device %d at %p for %p", 448 Size, DeviceId, TgtPtr, HstPtr) 449 450 memcpy(HstPtr, Reply->data().data(), Reply->data().size()); 451 452 return Reply; 453 } 454 CLIENT_DBG("Retrieving %lu-%lu/%lu bytes async from (%p) to (%p) " 455 "on Device %d", 456 Reply->start(), Reply->start() + Reply->data().size(), 457 Reply->size(), (void *)Request->tgt_ptr(), HstPtr, 458 Request->device_id()); 459 460 memcpy((void *)((char *)HstPtr + Reply->start()), 461 Reply->data().data(), Reply->data().size()); 462 } 463 RPCStatus = Reader->Finish(); 464 465 return Reply; 466 }, 467 /* Postprocess */ 468 [&](auto &Reply) { 469 if (!Reply->ret()) { 470 CLIENT_DBG("Async retrieve %ld bytes on Device %d", Size, DeviceId); 471 } else { 472 CLIENT_DBG("Could not async retrieve %ld bytes on Device %d", Size, 473 DeviceId); 474 } 475 return Reply->ret(); 476 }, 477 /* Error Value */ -1, 478 /* Timeout */ false); 479 } 480 481 int32_t RemoteOffloadClient::dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, 482 int32_t DstDevId, void *DstPtr, 483 int64_t Size, 484 __tgt_async_info *AsyncInfo) { 485 return remoteCall( 486 /* Preprocess */ 487 [&](auto &RPCStatus, auto &Context) { 488 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 489 auto *Request = 490 protobuf::Arena::CreateMessage<ExchangeDataAsync>(Arena.get()); 491 492 Request->set_src_dev_id(SrcDevId); 493 Request->set_src_ptr((uint64_t)SrcPtr); 494 Request->set_dst_dev_id(DstDevId); 495 Request->set_dst_ptr((uint64_t)DstPtr); 496 Request->set_size(Size); 497 Request->set_queue_ptr((uint64_t)AsyncInfo); 498 499 CLIENT_DBG( 500 "Exchanging %ld bytes on device %d at %p for %p on device %d", Size, 501 SrcDevId, SrcPtr, DstPtr, DstDevId); 502 RPCStatus = Stub->DataExchangeAsync(&Context, *Request, Reply); 503 return Reply; 504 }, 505 /* Postprocess */ 506 [&](auto &Reply) { 507 if (Reply->number()) { 508 CLIENT_DBG( 509 "Exchanged %ld bytes on device %d at %p for %p on device %d", 510 Size, SrcDevId, SrcPtr, DstPtr, DstDevId); 511 } else { 512 CLIENT_DBG("Could not exchange %ld bytes on device %d at %p for %p " 513 "on device %d", 514 Size, SrcDevId, SrcPtr, DstPtr, DstDevId); 515 } 516 return Reply->number(); 517 }, 518 /* Error Value */ -1); 519 } 520 521 int32_t RemoteOffloadClient::dataDelete(int32_t DeviceId, void *TgtPtr) { 522 return remoteCall( 523 /* Preprocess */ 524 [&](auto &RPCStatus, auto &Context) { 525 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 526 auto *Request = protobuf::Arena::CreateMessage<DeleteData>(Arena.get()); 527 528 Request->set_device_id(DeviceId); 529 Request->set_tgt_ptr((uint64_t)TgtPtr); 530 531 CLIENT_DBG("Deleting data at %p on device %d", TgtPtr, DeviceId) 532 RPCStatus = Stub->DataDelete(&Context, *Request, Reply); 533 return Reply; 534 }, 535 /* Postprocess */ 536 [&](auto &Reply) { 537 if (!Reply->number()) { 538 CLIENT_DBG("Deleted data at %p on device %d", TgtPtr, DeviceId) 539 } else { 540 CLIENT_DBG("Could not delete data at %p on device %d", TgtPtr, 541 DeviceId) 542 } 543 return Reply->number(); 544 }, 545 /* Error Value */ -1); 546 } 547 548 int32_t RemoteOffloadClient::runTargetRegionAsync( 549 int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, 550 int32_t ArgNum, __tgt_async_info *AsyncInfo) { 551 return remoteCall( 552 /* Preprocess */ 553 [&](auto &RPCStatus, auto &Context) { 554 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 555 auto *Request = 556 protobuf::Arena::CreateMessage<TargetRegionAsync>(Arena.get()); 557 558 Request->set_device_id(DeviceId); 559 Request->set_queue_ptr((uint64_t)AsyncInfo); 560 561 Request->set_tgt_entry_ptr( 562 (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]); 563 564 char **ArgPtr = (char **)TgtArgs; 565 for (auto I = 0; I < ArgNum; I++, ArgPtr++) 566 Request->add_tgt_args((uint64_t)*ArgPtr); 567 568 char *OffsetPtr = (char *)TgtOffsets; 569 for (auto I = 0; I < ArgNum; I++, OffsetPtr++) 570 Request->add_tgt_offsets((uint64_t)*OffsetPtr); 571 572 Request->set_arg_num(ArgNum); 573 574 CLIENT_DBG("Running target region async on device %d", DeviceId); 575 RPCStatus = Stub->RunTargetRegionAsync(&Context, *Request, Reply); 576 return Reply; 577 }, 578 /* Postprocess */ 579 [&](auto &Reply) { 580 if (!Reply->number()) { 581 CLIENT_DBG("Ran target region async on device %d", DeviceId); 582 } else { 583 CLIENT_DBG("Could not run target region async on device %d", 584 DeviceId); 585 } 586 return Reply->number(); 587 }, 588 /* Error Value */ -1, 589 /* Timeout */ false); 590 } 591 592 int32_t RemoteOffloadClient::runTargetTeamRegionAsync( 593 int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, 594 int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, 595 uint64_t LoopTripcount, __tgt_async_info *AsyncInfo) { 596 return remoteCall( 597 /* Preprocess */ 598 [&](auto &RPCStatus, auto &Context) { 599 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); 600 auto *Request = 601 protobuf::Arena::CreateMessage<TargetTeamRegionAsync>(Arena.get()); 602 603 Request->set_device_id(DeviceId); 604 Request->set_queue_ptr((uint64_t)AsyncInfo); 605 606 Request->set_tgt_entry_ptr( 607 (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]); 608 609 char **ArgPtr = (char **)TgtArgs; 610 for (auto I = 0; I < ArgNum; I++, ArgPtr++) { 611 Request->add_tgt_args((uint64_t)*ArgPtr); 612 } 613 614 char *OffsetPtr = (char *)TgtOffsets; 615 for (auto I = 0; I < ArgNum; I++, OffsetPtr++) 616 Request->add_tgt_offsets((uint64_t)*OffsetPtr); 617 618 Request->set_arg_num(ArgNum); 619 Request->set_team_num(TeamNum); 620 Request->set_thread_limit(ThreadLimit); 621 Request->set_loop_tripcount(LoopTripcount); 622 623 CLIENT_DBG("Running target team region async on device %d", DeviceId); 624 RPCStatus = Stub->RunTargetTeamRegionAsync(&Context, *Request, Reply); 625 return Reply; 626 }, 627 /* Postprocess */ 628 [&](auto &Reply) { 629 if (!Reply->number()) { 630 CLIENT_DBG("Ran target team region async on device %d", DeviceId); 631 } else { 632 CLIENT_DBG("Could not run target team region async on device %d", 633 DeviceId); 634 } 635 return Reply->number(); 636 }, 637 /* Error Value */ -1, 638 /* Timeout */ false); 639 } 640 641 // TODO: Better error handling for the next three functions 642 int32_t RemoteClientManager::shutdown(void) { 643 int32_t Ret = 0; 644 for (auto &Client : Clients) 645 Ret &= Client.shutdown(); 646 return Ret; 647 } 648 649 int32_t RemoteClientManager::registerLib(__tgt_bin_desc *Desc) { 650 int32_t Ret = 0; 651 for (auto &Client : Clients) 652 Ret &= Client.registerLib(Desc); 653 return Ret; 654 } 655 656 int32_t RemoteClientManager::unregisterLib(__tgt_bin_desc *Desc) { 657 int32_t Ret = 0; 658 for (auto &Client : Clients) 659 Ret &= Client.unregisterLib(Desc); 660 return Ret; 661 } 662 663 int32_t RemoteClientManager::isValidBinary(__tgt_device_image *Image) { 664 int32_t ClientIdx = 0; 665 for (auto &Client : Clients) { 666 if (auto Ret = Client.isValidBinary(Image)) 667 return Ret; 668 ClientIdx++; 669 } 670 return 0; 671 } 672 673 int32_t RemoteClientManager::getNumberOfDevices() { 674 auto ClientIdx = 0; 675 for (auto &Client : Clients) { 676 if (auto NumDevices = Client.getNumberOfDevices()) { 677 Devices.push_back(NumDevices); 678 } 679 ClientIdx++; 680 } 681 682 return std::accumulate(Devices.begin(), Devices.end(), 0); 683 } 684 685 std::pair<int32_t, int32_t> RemoteClientManager::mapDeviceId(int32_t DeviceId) { 686 for (size_t ClientIdx = 0; ClientIdx < Devices.size(); ClientIdx++) { 687 if (!(DeviceId >= Devices[ClientIdx])) 688 return {ClientIdx, DeviceId}; 689 DeviceId -= Devices[ClientIdx]; 690 } 691 return {-1, -1}; 692 } 693 694 int32_t RemoteClientManager::initDevice(int32_t DeviceId) { 695 int32_t ClientIdx, DeviceIdx; 696 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 697 return Clients[ClientIdx].initDevice(DeviceIdx); 698 } 699 700 int32_t RemoteClientManager::initRequires(int64_t RequiresFlags) { 701 for (auto &Client : Clients) 702 Client.initRequires(RequiresFlags); 703 704 return RequiresFlags; 705 } 706 707 __tgt_target_table *RemoteClientManager::loadBinary(int32_t DeviceId, 708 __tgt_device_image *Image) { 709 int32_t ClientIdx, DeviceIdx; 710 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 711 return Clients[ClientIdx].loadBinary(DeviceIdx, Image); 712 } 713 714 int64_t RemoteClientManager::synchronize(int32_t DeviceId, 715 __tgt_async_info *AsyncInfo) { 716 int32_t ClientIdx, DeviceIdx; 717 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 718 return Clients[ClientIdx].synchronize(DeviceIdx, AsyncInfo); 719 } 720 721 int32_t RemoteClientManager::isDataExchangeable(int32_t SrcDevId, 722 int32_t DstDevId) { 723 int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx; 724 std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId); 725 std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId); 726 return Clients[SrcClientIdx].isDataExchangeable(SrcDeviceIdx, DstDeviceIdx); 727 } 728 729 void *RemoteClientManager::dataAlloc(int32_t DeviceId, int64_t Size, 730 void *HstPtr) { 731 int32_t ClientIdx, DeviceIdx; 732 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 733 return Clients[ClientIdx].dataAlloc(DeviceIdx, Size, HstPtr); 734 } 735 736 int32_t RemoteClientManager::dataDelete(int32_t DeviceId, void *TgtPtr) { 737 int32_t ClientIdx, DeviceIdx; 738 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 739 return Clients[ClientIdx].dataDelete(DeviceIdx, TgtPtr); 740 } 741 742 int32_t RemoteClientManager::dataSubmitAsync(int32_t DeviceId, void *TgtPtr, 743 void *HstPtr, int64_t Size, 744 __tgt_async_info *AsyncInfo) { 745 int32_t ClientIdx, DeviceIdx; 746 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 747 return Clients[ClientIdx].dataSubmitAsync(DeviceIdx, TgtPtr, HstPtr, Size, 748 AsyncInfo); 749 } 750 751 int32_t RemoteClientManager::dataRetrieveAsync(int32_t DeviceId, void *HstPtr, 752 void *TgtPtr, int64_t Size, 753 __tgt_async_info *AsyncInfo) { 754 int32_t ClientIdx, DeviceIdx; 755 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 756 return Clients[ClientIdx].dataRetrieveAsync(DeviceIdx, HstPtr, TgtPtr, Size, 757 AsyncInfo); 758 } 759 760 int32_t RemoteClientManager::dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, 761 int32_t DstDevId, void *DstPtr, 762 int64_t Size, 763 __tgt_async_info *AsyncInfo) { 764 int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx; 765 std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId); 766 std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId); 767 return Clients[SrcClientIdx].dataExchangeAsync( 768 SrcDeviceIdx, SrcPtr, DstDeviceIdx, DstPtr, Size, AsyncInfo); 769 } 770 771 int32_t RemoteClientManager::runTargetRegionAsync( 772 int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, 773 int32_t ArgNum, __tgt_async_info *AsyncInfo) { 774 int32_t ClientIdx, DeviceIdx; 775 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 776 return Clients[ClientIdx].runTargetRegionAsync( 777 DeviceIdx, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, AsyncInfo); 778 } 779 780 int32_t RemoteClientManager::runTargetTeamRegionAsync( 781 int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, 782 int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, 783 uint64_t LoopTripCount, __tgt_async_info *AsyncInfo) { 784 int32_t ClientIdx, DeviceIdx; 785 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); 786 return Clients[ClientIdx].runTargetTeamRegionAsync( 787 DeviceIdx, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, TeamNum, ThreadLimit, 788 LoopTripCount, AsyncInfo); 789 } 790