1 /* 2 Copyright (c) 2021-2022 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #include "common/config.h" 18 19 #include <oneapi/tbb/task_arena.h> 20 #include <oneapi/tbb/concurrent_vector.h> 21 #include <oneapi/tbb/rw_mutex.h> 22 #include <oneapi/tbb/task_group.h> 23 #include <oneapi/tbb/parallel_for.h> 24 25 #include <oneapi/tbb/global_control.h> 26 27 #include "common/test.h" 28 #include "common/utils.h" 29 #include "common/utils_concurrency_limit.h" 30 #include "common/spin_barrier.h" 31 32 #include <stdlib.h> // C11/POSIX aligned_alloc 33 #include <random> 34 35 36 //! \file test_scheduler_mix.cpp 37 //! \brief Test for [scheduler.task_arena scheduler.task_scheduler_observer] specification 38 39 const std::uint64_t maxNumActions = 1 * 100 * 1000; 40 static std::atomic<std::uint64_t> globalNumActions{}; 41 42 //using Random = utils::FastRandom<>; 43 class Random { 44 struct State { 45 std::random_device rd; 46 std::mt19937 gen; 47 std::uniform_int_distribution<> dist; 48 49 State() : gen(rd()), dist(0, std::numeric_limits<unsigned short>::max()) {} 50 51 int get() { 52 return dist(gen); 53 } 54 }; 55 static thread_local State* mState; 56 tbb::concurrent_vector<State*> mStateList; 57 public: 58 ~Random() { 59 for (auto s : mStateList) { 60 delete s; 61 } 62 } 63 64 int get() { 65 auto& s = mState; 66 if (!s) { 67 s = new State; 68 mStateList.push_back(s); 69 } 70 return s->get(); 71 } 72 }; 73 74 thread_local Random::State* Random::mState = nullptr; 75 76 77 void* aligned_malloc(std::size_t alignment, std::size_t size) { 78 #if _WIN32 79 return _aligned_malloc(size, alignment); 80 #elif __unix__ || __APPLE__ 81 void* ptr{}; 82 int res = posix_memalign(&ptr, alignment, size); 83 CHECK(res == 0); 84 return ptr; 85 #else 86 return aligned_alloc(alignment, size); 87 #endif 88 } 89 90 void aligned_free(void* ptr) { 91 #if _WIN32 92 _aligned_free(ptr); 93 #else 94 free(ptr); 95 #endif 96 } 97 98 template <typename T, std::size_t Alignment> 99 class PtrRWMutex { 100 static const std::size_t maxThreads = (Alignment >> 1) - 1; 101 static const std::uintptr_t READER_MASK = maxThreads; // 7F.. 102 static const std::uintptr_t LOCKED = Alignment - 1; // FF.. 103 static const std::uintptr_t LOCKED_MASK = LOCKED; // FF.. 104 static const std::uintptr_t LOCK_PENDING = READER_MASK + 1; // 80.. 105 106 std::atomic<std::uintptr_t> mState; 107 108 T* pointer() { 109 return reinterpret_cast<T*>(state() & ~LOCKED_MASK); 110 } 111 112 std::uintptr_t state() { 113 return mState.load(std::memory_order_relaxed); 114 } 115 116 public: 117 class ScopedLock { 118 public: 119 constexpr ScopedLock() : mMutex(nullptr), mIsWriter(false) {} 120 //! Acquire lock on given mutex. 121 ScopedLock(PtrRWMutex& m, bool write = true) : mMutex(nullptr) { 122 CHECK_FAST(write == true); 123 acquire(m); 124 } 125 //! Release lock (if lock is held). 126 ~ScopedLock() { 127 if (mMutex) { 128 release(); 129 } 130 } 131 //! No Copy 132 ScopedLock(const ScopedLock&) = delete; 133 ScopedLock& operator=(const ScopedLock&) = delete; 134 135 //! Acquire lock on given mutex. 136 void acquire(PtrRWMutex& m) { 137 CHECK_FAST(mMutex == nullptr); 138 mIsWriter = true; 139 mMutex = &m; 140 mMutex->lock(); 141 } 142 143 //! Try acquire lock on given mutex. 144 bool tryAcquire(PtrRWMutex& m, bool write = true) { 145 bool succeed = write ? m.tryLock() : m.tryLockShared(); 146 if (succeed) { 147 mMutex = &m; 148 mIsWriter = write; 149 } 150 return succeed; 151 } 152 153 void clear() { 154 CHECK_FAST(mMutex != nullptr); 155 CHECK_FAST(mIsWriter); 156 PtrRWMutex* m = mMutex; 157 mMutex = nullptr; 158 m->clear(); 159 } 160 161 //! Release lock. 162 void release() { 163 CHECK_FAST(mMutex != nullptr); 164 PtrRWMutex* m = mMutex; 165 mMutex = nullptr; 166 167 if (mIsWriter) { 168 m->unlock(); 169 } 170 else { 171 m->unlockShared(); 172 } 173 } 174 protected: 175 PtrRWMutex* mMutex{}; 176 bool mIsWriter{}; 177 }; 178 179 bool trySet(T* ptr) { 180 auto p = reinterpret_cast<std::uintptr_t>(ptr); 181 CHECK_FAST((p & (Alignment - 1)) == 0); 182 if (!state()) { 183 std::uintptr_t expected = 0; 184 if (mState.compare_exchange_strong(expected, p)) { 185 return true; 186 } 187 } 188 return false; 189 } 190 191 void clear() { 192 CHECK_FAST((state() & LOCKED_MASK) == LOCKED); 193 mState.store(0, std::memory_order_relaxed); 194 } 195 196 bool tryLock() { 197 auto v = state(); 198 if (v == 0) { 199 return false; 200 } 201 CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads); 202 if ((v & READER_MASK) == 0) { 203 if (mState.compare_exchange_strong(v, v | LOCKED)) { 204 return true; 205 } 206 } 207 return false; 208 } 209 210 bool tryLockShared() { 211 auto v = state(); 212 if (v == 0) { 213 return false; 214 } 215 CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads); 216 if ((v & LOCKED_MASK) != LOCKED && (v & LOCK_PENDING) == 0) { 217 if (mState.compare_exchange_strong(v, v + 1)) { 218 return true; 219 } 220 } 221 return false; 222 } 223 224 void lock() { 225 auto v = state(); 226 mState.compare_exchange_strong(v, v | LOCK_PENDING); 227 while (!tryLock()) { 228 utils::yield(); 229 } 230 } 231 232 void unlock() { 233 auto v = state(); 234 CHECK_FAST((v & LOCKED_MASK) == LOCKED); 235 mState.store(v & ~LOCKED, std::memory_order_release); 236 } 237 238 void unlockShared() { 239 auto v = state(); 240 CHECK_FAST((v & LOCKED_MASK) != LOCKED); 241 CHECK_FAST((v & READER_MASK) > 0); 242 mState -= 1; 243 } 244 245 operator bool() const { 246 return pointer() != 0; 247 } 248 249 T* get() { 250 return pointer(); 251 } 252 }; 253 254 class Statistics { 255 public: 256 enum ACTION { 257 ArenaCreate, 258 ArenaDestroy, 259 ArenaAcquire, 260 skippedArenaCreate, 261 skippedArenaDestroy, 262 skippedArenaAcquire, 263 ParallelAlgorithm, 264 ArenaEnqueue, 265 ArenaExecute, 266 numActions 267 }; 268 269 static const char* const mStatNames[numActions]; 270 private: 271 struct StatType { 272 StatType() : mCounters() {} 273 std::array<std::uint64_t, numActions> mCounters; 274 }; 275 276 tbb::concurrent_vector<StatType*> mStatsList; 277 static thread_local StatType* mStats; 278 279 StatType& get() { 280 auto& s = mStats; 281 if (!s) { 282 s = new StatType; 283 mStatsList.push_back(s); 284 } 285 return *s; 286 } 287 public: 288 ~Statistics() { 289 for (auto s : mStatsList) { 290 delete s; 291 } 292 } 293 294 void notify(ACTION a) { 295 ++get().mCounters[a]; 296 } 297 298 void report() { 299 StatType summary; 300 for (auto& s : mStatsList) { 301 for (int i = 0; i < numActions; ++i) { 302 summary.mCounters[i] += s->mCounters[i]; 303 } 304 } 305 std::cout << std::endl << "Statistics:" << std::endl; 306 std::cout << "Total actions: " << globalNumActions << std::endl; 307 for (int i = 0; i < numActions; ++i) { 308 std::cout << mStatNames[i] << ": " << summary.mCounters[i] << std::endl; 309 } 310 } 311 }; 312 313 314 const char* const Statistics::mStatNames[Statistics::numActions] = { 315 "Arena create", "Arena destroy", "Arena acquire", 316 "Skipped arena create", "Skipped arena destroy", "Skipped arena acquire", 317 "Parallel algorithm", "Arena enqueue", "Arena execute" 318 }; 319 thread_local Statistics::StatType* Statistics::mStats; 320 321 static Statistics gStats; 322 323 class LifetimeTracker { 324 public: 325 LifetimeTracker() = default; 326 327 class Guard { 328 public: 329 Guard(LifetimeTracker* obj) { 330 if (!(obj->mReferences.load(std::memory_order_relaxed) & SHUTDOWN_FLAG)) { 331 if (obj->mReferences.fetch_add(REFERENCE_FLAG) & SHUTDOWN_FLAG) { 332 obj->mReferences.fetch_sub(REFERENCE_FLAG); 333 } else { 334 mObj = obj; 335 } 336 } 337 } 338 339 Guard(Guard&& ing) : mObj(ing.mObj) { 340 ing.mObj = nullptr; 341 } 342 343 ~Guard() { 344 if (mObj != nullptr) { 345 mObj->mReferences.fetch_sub(REFERENCE_FLAG); 346 } 347 } 348 349 bool continue_execution() { 350 return mObj != nullptr; 351 } 352 353 private: 354 LifetimeTracker* mObj{nullptr}; 355 }; 356 357 Guard makeGuard() { 358 return Guard(this); 359 } 360 361 void signalShutdown() { 362 mReferences.fetch_add(SHUTDOWN_FLAG); 363 } 364 365 void waitCompletion() { 366 utils::SpinWaitUntilEq(mReferences, SHUTDOWN_FLAG); 367 } 368 369 private: 370 friend class Guard; 371 static constexpr std::uintptr_t SHUTDOWN_FLAG = 1; 372 static constexpr std::uintptr_t REFERENCE_FLAG = 1 << 1; 373 std::atomic<std::uintptr_t> mReferences{}; 374 }; 375 376 class ArenaTable { 377 static const std::size_t maxArenas = 64; 378 static const std::size_t maxThreads = 1 << 9; 379 static const std::size_t arenaAligment = maxThreads << 1; 380 381 using ArenaPtrRWMutex = PtrRWMutex<tbb::task_arena, arenaAligment>; 382 std::array<ArenaPtrRWMutex, maxArenas> mArenaTable; 383 384 struct ThreadState { 385 bool lockedArenas[maxArenas]{}; 386 int arenaIdxStack[maxArenas]; 387 int level{}; 388 }; 389 390 LifetimeTracker mLifetimeTracker{}; 391 static thread_local ThreadState mThreadState; 392 393 template <typename F> 394 auto find_arena(std::size_t start, F f) -> decltype(f(std::declval<ArenaPtrRWMutex&>(), std::size_t{})) { 395 for (std::size_t idx = start, i = 0; i < maxArenas; ++i, idx = (idx + 1) % maxArenas) { 396 auto res = f(mArenaTable[idx], idx); 397 if (res) { 398 return res; 399 } 400 } 401 return {}; 402 } 403 404 public: 405 using ScopedLock = ArenaPtrRWMutex::ScopedLock; 406 407 void create(Random& rnd) { 408 auto guard = mLifetimeTracker.makeGuard(); 409 if (guard.continue_execution()) { 410 int num_threads = rnd.get() % utils::get_platform_max_threads() + 1; 411 unsigned int num_reserved = rnd.get() % num_threads; 412 tbb::task_arena::priority priorities[] = { tbb::task_arena::priority::low , tbb::task_arena::priority::normal, tbb::task_arena::priority::high }; 413 tbb::task_arena::priority priority = priorities[rnd.get() % 3]; 414 415 tbb::task_arena* a = new (aligned_malloc(arenaAligment, arenaAligment)) tbb::task_arena{ num_threads , num_reserved , priority }; 416 417 if (!find_arena(rnd.get() % maxArenas, [a](ArenaPtrRWMutex& arena, std::size_t) -> bool { 418 if (arena.trySet(a)) { 419 return true; 420 } 421 return false; 422 })) 423 { 424 gStats.notify(Statistics::skippedArenaCreate); 425 a->~task_arena(); 426 aligned_free(a); 427 } 428 } 429 } 430 431 void destroy(Random& rnd) { 432 auto guard = mLifetimeTracker.makeGuard(); 433 if (guard.continue_execution()) { 434 auto& ts = mThreadState; 435 if (!find_arena(rnd.get() % maxArenas, [&ts](ArenaPtrRWMutex& arena, std::size_t idx) { 436 if (!ts.lockedArenas[idx]) { 437 ScopedLock lock; 438 if (lock.tryAcquire(arena, true)) { 439 auto a = arena.get(); 440 lock.clear(); 441 a->~task_arena(); 442 aligned_free(a); 443 return true; 444 } 445 } 446 return false; 447 })) 448 { 449 gStats.notify(Statistics::skippedArenaDestroy); 450 } 451 } 452 } 453 454 void shutdown() { 455 mLifetimeTracker.signalShutdown(); 456 mLifetimeTracker.waitCompletion(); 457 find_arena(0, [](ArenaPtrRWMutex& arena, std::size_t) { 458 if (arena.get()) { 459 ScopedLock lock{ arena, true }; 460 auto a = arena.get(); 461 lock.clear(); 462 a->~task_arena(); 463 aligned_free(a); 464 } 465 return false; 466 }); 467 } 468 469 std::pair<tbb::task_arena*, std::size_t> acquire(Random& rnd, ScopedLock& lock) { 470 auto guard = mLifetimeTracker.makeGuard(); 471 472 tbb::task_arena* a{nullptr}; 473 std::size_t resIdx{}; 474 if (guard.continue_execution()) { 475 auto& ts = mThreadState; 476 a = find_arena(rnd.get() % maxArenas, 477 [&ts, &lock, &resIdx](ArenaPtrRWMutex& arena, std::size_t idx) -> tbb::task_arena* { 478 if (!ts.lockedArenas[idx]) { 479 if (lock.tryAcquire(arena, false)) { 480 ts.lockedArenas[idx] = true; 481 ts.arenaIdxStack[ts.level++] = int(idx); 482 resIdx = idx; 483 return arena.get(); 484 } 485 } 486 return nullptr; 487 }); 488 if (!a) { 489 gStats.notify(Statistics::skippedArenaAcquire); 490 } 491 } 492 return { a, resIdx }; 493 } 494 495 void release(ScopedLock& lock) { 496 auto& ts = mThreadState; 497 CHECK_FAST(ts.level > 0); 498 auto idx = ts.arenaIdxStack[--ts.level]; 499 CHECK_FAST(ts.lockedArenas[idx]); 500 ts.lockedArenas[idx] = false; 501 lock.release(); 502 } 503 }; 504 505 thread_local ArenaTable::ThreadState ArenaTable::mThreadState; 506 507 static ArenaTable arenaTable; 508 static Random threadRandom; 509 510 enum ACTIONS { 511 arena_create, 512 arena_destroy, 513 arena_action, 514 parallel_algorithm, 515 // TODO: 516 // observer_attach, 517 // observer_detach, 518 // flow_graph, 519 // task_group, 520 // resumable_tasks, 521 522 num_actions 523 }; 524 525 void global_actor(); 526 527 template <ACTIONS action> 528 struct actor; 529 530 template <> 531 struct actor<arena_create> { 532 static void do_it(Random& r) { 533 arenaTable.create(r); 534 } 535 }; 536 537 template <> 538 struct actor<arena_destroy> { 539 static void do_it(Random& r) { 540 arenaTable.destroy(r); 541 } 542 }; 543 544 template <> 545 struct actor<arena_action> { 546 static void do_it(Random& r) { 547 static thread_local std::size_t arenaLevel = 0; 548 ArenaTable::ScopedLock lock; 549 auto entry = arenaTable.acquire(r, lock); 550 if (entry.first) { 551 enum arena_actions { 552 arena_execute, 553 arena_enqueue, 554 num_arena_actions 555 }; 556 auto process = r.get() % 2; 557 auto body = [process] { 558 if (process) { 559 tbb::detail::d1::wait_context wctx{ 1 }; 560 tbb::task_group_context ctx; 561 tbb::this_task_arena::enqueue([&wctx] { wctx.release(); }); 562 tbb::detail::d1::wait(wctx, ctx); 563 } else { 564 global_actor(); 565 } 566 }; 567 switch (r.get() % (16*num_arena_actions)) { 568 case arena_execute: 569 if (entry.second > arenaLevel) { 570 gStats.notify(Statistics::ArenaExecute); 571 auto oldArenaLevel = arenaLevel; 572 arenaLevel = entry.second; 573 entry.first->execute(body); 574 arenaLevel = oldArenaLevel; 575 break; 576 } 577 utils_fallthrough; 578 case arena_enqueue: 579 utils_fallthrough; 580 default: 581 gStats.notify(Statistics::ArenaEnqueue); 582 entry.first->enqueue([] { global_actor(); }); 583 break; 584 } 585 arenaTable.release(lock); 586 } 587 } 588 }; 589 590 template <> 591 struct actor<parallel_algorithm> { 592 static void do_it(Random& rnd) { 593 enum PARTITIONERS { 594 simpl_part, 595 auto_part, 596 aff_part, 597 static_part, 598 num_parts 599 }; 600 int sz = rnd.get() % 10000; 601 auto doGlbAction = rnd.get() % 1000 == 42; 602 auto body = [doGlbAction, sz](int i) { 603 if (i == sz / 2 && doGlbAction) { 604 global_actor(); 605 } 606 }; 607 608 switch (rnd.get() % num_parts) { 609 case simpl_part: 610 tbb::parallel_for(0, sz, body, tbb::simple_partitioner{}); break; 611 case auto_part: 612 tbb::parallel_for(0, sz, body, tbb::auto_partitioner{}); break; 613 case aff_part: 614 { 615 tbb::affinity_partitioner aff; 616 tbb::parallel_for(0, sz, body, aff); break; 617 } 618 case static_part: 619 tbb::parallel_for(0, sz, body, tbb::static_partitioner{}); break; 620 } 621 } 622 }; 623 624 void global_actor() { 625 static thread_local std::uint64_t localNumActions{}; 626 627 while (globalNumActions < maxNumActions) { 628 auto& rnd = threadRandom; 629 switch (rnd.get() % num_actions) { 630 case arena_create: gStats.notify(Statistics::ArenaCreate); actor<arena_create>::do_it(rnd); break; 631 case arena_destroy: gStats.notify(Statistics::ArenaDestroy); actor<arena_destroy>::do_it(rnd); break; 632 case arena_action: gStats.notify(Statistics::ArenaAcquire); actor<arena_action>::do_it(rnd); break; 633 case parallel_algorithm: gStats.notify(Statistics::ParallelAlgorithm); actor<parallel_algorithm>::do_it(rnd); break; 634 } 635 636 if (++localNumActions == 100) { 637 localNumActions = 0; 638 globalNumActions += 100; 639 640 // TODO: Enable statistics 641 // static std::mutex mutex; 642 // std::lock_guard<std::mutex> lock{ mutex }; 643 // std::cout << globalNumActions << "\r" << std::flush; 644 } 645 } 646 } 647 648 #if TBB_USE_EXCEPTIONS 649 //! \brief \ref stress 650 TEST_CASE("Stress test with mixing functionality") { 651 // TODO add thread recreation 652 // TODO: Enable statistics 653 tbb::task_scheduler_handle handle{ tbb::attach{} }; 654 655 const std::size_t numExtraThreads = 16; 656 utils::SpinBarrier startBarrier{numExtraThreads}; 657 utils::NativeParallelFor(numExtraThreads, [&startBarrier](std::size_t) { 658 startBarrier.wait(); 659 global_actor(); 660 }); 661 662 arenaTable.shutdown(); 663 664 tbb::finalize(handle); 665 666 // gStats.report(); 667 } 668 #endif 669