1 /* 2 Copyright (c) 2021 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #define TBB_PREVIEW_MUTEXES 1 18 #define TBB_PREVIEW_WAITING_FOR_WORKERS 1 19 #define TBB_PREVIEW_TASK_GROUP_EXTENSIONS 1 20 21 #include "common/config.h" 22 23 #include <oneapi/tbb/task_arena.h> 24 #include <oneapi/tbb/concurrent_vector.h> 25 #include <oneapi/tbb/rw_mutex.h> 26 #include <oneapi/tbb/task_group.h> 27 #include <oneapi/tbb/parallel_for.h> 28 29 #include <oneapi/tbb/global_control.h> 30 31 #include "common/test.h" 32 #include "common/utils.h" 33 #include "common/utils_concurrency_limit.h" 34 #include "common/spin_barrier.h" 35 36 #include <stdlib.h> // C11/POSIX aligned_alloc 37 #include <random> 38 39 40 //! \file test_scheduler_mix.cpp 41 //! \brief Test for [scheduler.task_arena scheduler.task_scheduler_observer] specification 42 43 const std::uint64_t maxNumActions = 1 * 100 * 1000; 44 static std::atomic<std::uint64_t> globalNumActions{}; 45 46 //using Random = utils::FastRandom<>; 47 class Random { 48 struct State { 49 std::random_device rd; 50 std::mt19937 gen; 51 std::uniform_int_distribution<> dist; 52 53 State() : gen(rd()), dist(0, std::numeric_limits<unsigned short>::max()) {} 54 55 int get() { 56 return dist(gen); 57 } 58 }; 59 static thread_local State* mState; 60 tbb::concurrent_vector<State*> mStateList; 61 public: 62 ~Random() { 63 for (auto s : mStateList) { 64 delete s; 65 } 66 } 67 68 int get() { 69 auto& s = mState; 70 if (!s) { 71 s = new State; 72 mStateList.push_back(s); 73 } 74 return s->get(); 75 } 76 }; 77 78 thread_local Random::State* Random::mState = nullptr; 79 80 81 void* aligned_malloc(std::size_t alignment, std::size_t size) { 82 #if _WIN32 83 return _aligned_malloc(size, alignment); 84 #elif __unix__ || __APPLE__ 85 void* ptr{}; 86 int res = posix_memalign(&ptr, alignment, size); 87 CHECK(res == 0); 88 return ptr; 89 #else 90 return aligned_alloc(alignment, size); 91 #endif 92 } 93 94 void aligned_free(void* ptr) { 95 #if _WIN32 96 _aligned_free(ptr); 97 #else 98 free(ptr); 99 #endif 100 } 101 102 template <typename T, std::size_t Alignment> 103 class PtrRWMutex { 104 static const std::size_t maxThreads = (Alignment >> 1) - 1; 105 static const std::uintptr_t READER_MASK = maxThreads; // 7F.. 106 static const std::uintptr_t LOCKED = Alignment - 1; // FF.. 107 static const std::uintptr_t LOCKED_MASK = LOCKED; // FF.. 108 static const std::uintptr_t LOCK_PENDING = READER_MASK + 1; // 80.. 109 110 std::atomic<std::uintptr_t> mState; 111 112 T* pointer() { 113 return reinterpret_cast<T*>(state() & ~LOCKED_MASK); 114 } 115 116 std::uintptr_t state() { 117 return mState.load(std::memory_order_relaxed); 118 } 119 120 public: 121 class ScopedLock { 122 public: 123 constexpr ScopedLock() : mMutex(nullptr), mIsWriter(false) {} 124 //! Acquire lock on given mutex. 125 ScopedLock(PtrRWMutex& m, bool write = true) : mMutex(nullptr) { 126 CHECK_FAST(write == true); 127 acquire(m); 128 } 129 //! Release lock (if lock is held). 130 ~ScopedLock() { 131 if (mMutex) { 132 release(); 133 } 134 } 135 //! No Copy 136 ScopedLock(const ScopedLock&) = delete; 137 ScopedLock& operator=(const ScopedLock&) = delete; 138 139 //! Acquire lock on given mutex. 140 void acquire(PtrRWMutex& m) { 141 CHECK_FAST(mMutex == nullptr); 142 mIsWriter = true; 143 mMutex = &m; 144 mMutex->lock(); 145 } 146 147 //! Try acquire lock on given mutex. 148 bool tryAcquire(PtrRWMutex& m, bool write = true) { 149 bool succeed = write ? m.tryLock() : m.tryLockShared(); 150 if (succeed) { 151 mMutex = &m; 152 mIsWriter = write; 153 } 154 return succeed; 155 } 156 157 void clear() { 158 CHECK_FAST(mMutex != nullptr); 159 CHECK_FAST(mIsWriter); 160 PtrRWMutex* m = mMutex; 161 mMutex = nullptr; 162 m->clear(); 163 } 164 165 //! Release lock. 166 void release() { 167 CHECK_FAST(mMutex != nullptr); 168 PtrRWMutex* m = mMutex; 169 mMutex = nullptr; 170 171 if (mIsWriter) { 172 m->unlock(); 173 } 174 else { 175 m->unlockShared(); 176 } 177 } 178 protected: 179 PtrRWMutex* mMutex{}; 180 bool mIsWriter{}; 181 }; 182 183 bool trySet(T* ptr) { 184 auto p = reinterpret_cast<std::uintptr_t>(ptr); 185 CHECK_FAST((p & (Alignment - 1)) == 0); 186 if (!state()) { 187 std::uintptr_t expected = 0; 188 if (mState.compare_exchange_strong(expected, p)) { 189 return true; 190 } 191 } 192 return false; 193 } 194 195 void clear() { 196 CHECK_FAST((state() & LOCKED_MASK) == LOCKED); 197 mState.store(0, std::memory_order_relaxed); 198 } 199 200 bool tryLock() { 201 auto v = state(); 202 if (v == 0) { 203 return false; 204 } 205 CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads); 206 if ((v & READER_MASK) == 0) { 207 if (mState.compare_exchange_strong(v, v | LOCKED)) { 208 return true; 209 } 210 } 211 return false; 212 } 213 214 bool tryLockShared() { 215 auto v = state(); 216 if (v == 0) { 217 return false; 218 } 219 CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads); 220 if ((v & LOCKED_MASK) != LOCKED && (v & LOCK_PENDING) == 0) { 221 if (mState.compare_exchange_strong(v, v + 1)) { 222 return true; 223 } 224 } 225 return false; 226 } 227 228 void lock() { 229 auto v = state(); 230 mState.compare_exchange_strong(v, v | LOCK_PENDING); 231 while (!tryLock()) { 232 utils::yield(); 233 } 234 } 235 236 void unlock() { 237 auto v = state(); 238 CHECK_FAST((v & LOCKED_MASK) == LOCKED); 239 mState.store(v & ~LOCKED, std::memory_order_release); 240 } 241 242 void unlockShared() { 243 auto v = state(); 244 CHECK_FAST((v & LOCKED_MASK) != LOCKED); 245 CHECK_FAST((v & READER_MASK) > 0); 246 mState -= 1; 247 } 248 249 operator bool() const { 250 return pointer() != 0; 251 } 252 253 T* get() { 254 return pointer(); 255 } 256 }; 257 258 class Statistics { 259 public: 260 enum ACTION { 261 ArenaCreate, 262 ArenaDestroy, 263 ArenaAcquire, 264 skippedArenaCreate, 265 skippedArenaDestroy, 266 skippedArenaAcquire, 267 ParallelAlgorithm, 268 ArenaEnqueue, 269 ArenaExecute, 270 numActions 271 }; 272 273 static const char* const mStatNames[numActions]; 274 private: 275 struct StatType { 276 StatType() : mCounters() {} 277 std::array<std::uint64_t, numActions> mCounters; 278 }; 279 280 tbb::concurrent_vector<StatType*> mStatsList; 281 static thread_local StatType* mStats; 282 283 StatType& get() { 284 auto& s = mStats; 285 if (!s) { 286 s = new StatType; 287 mStatsList.push_back(s); 288 } 289 return *s; 290 } 291 public: 292 ~Statistics() { 293 for (auto s : mStatsList) { 294 delete s; 295 } 296 } 297 298 void notify(ACTION a) { 299 ++get().mCounters[a]; 300 } 301 302 void report() { 303 StatType summary; 304 for (auto& s : mStatsList) { 305 for (int i = 0; i < numActions; ++i) { 306 summary.mCounters[i] += s->mCounters[i]; 307 } 308 } 309 std::cout << std::endl << "Statistics:" << std::endl; 310 std::cout << "Total actions: " << globalNumActions << std::endl; 311 for (int i = 0; i < numActions; ++i) { 312 std::cout << mStatNames[i] << ": " << summary.mCounters[i] << std::endl; 313 } 314 } 315 }; 316 317 318 const char* const Statistics::mStatNames[Statistics::numActions] = { 319 "Arena create", "Arena destroy", "Arena acquire", 320 "Skipped arena create", "Skipped arena destroy", "Skipped arena acquire", 321 "Parallel algorithm", "Arena enqueue", "Arena execute" 322 }; 323 thread_local Statistics::StatType* Statistics::mStats; 324 325 static Statistics gStats; 326 327 class LifetimeTracker { 328 public: 329 LifetimeTracker() = default; 330 331 class Guard { 332 public: 333 Guard(LifetimeTracker* obj) { 334 if (!(obj->mReferences.load(std::memory_order_relaxed) & SHUTDOWN_FLAG)) { 335 if (obj->mReferences.fetch_add(REFERENCE_FLAG) & SHUTDOWN_FLAG) { 336 obj->mReferences.fetch_sub(REFERENCE_FLAG); 337 } else { 338 mObj = obj; 339 } 340 } 341 } 342 343 Guard(Guard&& ing) : mObj(ing.mObj) { 344 ing.mObj = nullptr; 345 } 346 347 ~Guard() { 348 if (mObj != nullptr) { 349 mObj->mReferences.fetch_sub(REFERENCE_FLAG); 350 } 351 } 352 353 bool continue_execution() { 354 return mObj != nullptr; 355 } 356 357 private: 358 LifetimeTracker* mObj{nullptr}; 359 }; 360 361 Guard makeGuard() { 362 return Guard(this); 363 } 364 365 void signalShutdown() { 366 mReferences.fetch_add(SHUTDOWN_FLAG); 367 } 368 369 void waitCompletion() { 370 utils::SpinWaitUntilEq(mReferences, SHUTDOWN_FLAG); 371 } 372 373 private: 374 friend class Guard; 375 static constexpr std::uintptr_t SHUTDOWN_FLAG = 1; 376 static constexpr std::uintptr_t REFERENCE_FLAG = 1 << 1; 377 std::atomic<std::uintptr_t> mReferences{}; 378 }; 379 380 class ArenaTable { 381 static const std::size_t maxArenas = 64; 382 static const std::size_t maxThreads = 1 << 9; 383 static const std::size_t arenaAligment = maxThreads << 1; 384 385 using ArenaPtrRWMutex = PtrRWMutex<tbb::task_arena, arenaAligment>; 386 std::array<ArenaPtrRWMutex, maxArenas> mArenaTable; 387 388 struct ThreadState { 389 bool lockedArenas[maxArenas]{}; 390 int arenaIdxStack[maxArenas]; 391 int level{}; 392 }; 393 394 LifetimeTracker mLifetimeTracker{}; 395 static thread_local ThreadState mThreadState; 396 397 template <typename F> 398 auto find_arena(std::size_t start, F f) -> decltype(f(std::declval<ArenaPtrRWMutex&>(), std::size_t{})) { 399 for (std::size_t idx = start, i = 0; i < maxArenas; ++i, idx = (idx + 1) % maxArenas) { 400 auto res = f(mArenaTable[idx], idx); 401 if (res) { 402 return res; 403 } 404 } 405 return {}; 406 } 407 408 public: 409 using ScopedLock = ArenaPtrRWMutex::ScopedLock; 410 411 void create(Random& rnd) { 412 auto guard = mLifetimeTracker.makeGuard(); 413 if (guard.continue_execution()) { 414 int num_threads = rnd.get() % utils::get_platform_max_threads() + 1; 415 unsigned int num_reserved = rnd.get() % num_threads; 416 tbb::task_arena::priority priorities[] = { tbb::task_arena::priority::low , tbb::task_arena::priority::normal, tbb::task_arena::priority::high }; 417 tbb::task_arena::priority priority = priorities[rnd.get() % 3]; 418 419 tbb::task_arena* a = new (aligned_malloc(arenaAligment, arenaAligment)) tbb::task_arena{ num_threads , num_reserved , priority }; 420 421 if (!find_arena(rnd.get() % maxArenas, [a](ArenaPtrRWMutex& arena, std::size_t) -> bool { 422 if (arena.trySet(a)) { 423 return true; 424 } 425 return false; 426 })) 427 { 428 gStats.notify(Statistics::skippedArenaCreate); 429 a->~task_arena(); 430 aligned_free(a); 431 } 432 } 433 } 434 435 void destroy(Random& rnd) { 436 auto guard = mLifetimeTracker.makeGuard(); 437 if (guard.continue_execution()) { 438 auto& ts = mThreadState; 439 if (!find_arena(rnd.get() % maxArenas, [&ts](ArenaPtrRWMutex& arena, std::size_t idx) { 440 if (!ts.lockedArenas[idx]) { 441 ScopedLock lock; 442 if (lock.tryAcquire(arena, true)) { 443 auto a = arena.get(); 444 lock.clear(); 445 a->~task_arena(); 446 aligned_free(a); 447 return true; 448 } 449 } 450 return false; 451 })) 452 { 453 gStats.notify(Statistics::skippedArenaDestroy); 454 } 455 } 456 } 457 458 void shutdown() { 459 mLifetimeTracker.signalShutdown(); 460 mLifetimeTracker.waitCompletion(); 461 find_arena(0, [](ArenaPtrRWMutex& arena, std::size_t) { 462 if (arena.get()) { 463 ScopedLock lock{ arena, true }; 464 auto a = arena.get(); 465 lock.clear(); 466 a->~task_arena(); 467 aligned_free(a); 468 } 469 return false; 470 }); 471 } 472 473 std::pair<tbb::task_arena*, std::size_t> acquire(Random& rnd, ScopedLock& lock) { 474 auto guard = mLifetimeTracker.makeGuard(); 475 476 tbb::task_arena* a{nullptr}; 477 std::size_t resIdx{}; 478 if (guard.continue_execution()) { 479 auto& ts = mThreadState; 480 a = find_arena(rnd.get() % maxArenas, 481 [&ts, &lock, &resIdx](ArenaPtrRWMutex& arena, std::size_t idx) -> tbb::task_arena* { 482 if (!ts.lockedArenas[idx]) { 483 if (lock.tryAcquire(arena, false)) { 484 ts.lockedArenas[idx] = true; 485 ts.arenaIdxStack[ts.level++] = int(idx); 486 resIdx = idx; 487 return arena.get(); 488 } 489 } 490 return nullptr; 491 }); 492 if (!a) { 493 gStats.notify(Statistics::skippedArenaAcquire); 494 } 495 } 496 return { a, resIdx }; 497 } 498 499 void release(ScopedLock& lock) { 500 auto& ts = mThreadState; 501 CHECK_FAST(ts.level > 0); 502 auto idx = ts.arenaIdxStack[--ts.level]; 503 CHECK_FAST(ts.lockedArenas[idx]); 504 ts.lockedArenas[idx] = false; 505 lock.release(); 506 } 507 }; 508 509 thread_local ArenaTable::ThreadState ArenaTable::mThreadState; 510 511 static ArenaTable arenaTable; 512 static Random threadRandom; 513 514 enum ACTIONS { 515 arena_create, 516 arena_destroy, 517 arena_action, 518 parallel_algorithm, 519 // TODO: 520 // observer_attach, 521 // observer_detach, 522 // flow_graph, 523 // task_group, 524 // resumable_tasks, 525 526 num_actions 527 }; 528 529 void global_actor(); 530 531 template <ACTIONS action> 532 struct actor; 533 534 template <> 535 struct actor<arena_create> { 536 static void do_it(Random& r) { 537 arenaTable.create(r); 538 } 539 }; 540 541 template <> 542 struct actor<arena_destroy> { 543 static void do_it(Random& r) { 544 arenaTable.destroy(r); 545 } 546 }; 547 548 template <> 549 struct actor<arena_action> { 550 static void do_it(Random& r) { 551 static thread_local std::size_t arenaLevel = 0; 552 ArenaTable::ScopedLock lock; 553 auto entry = arenaTable.acquire(r, lock); 554 if (entry.first) { 555 enum arena_actions { 556 arena_execute, 557 arena_enqueue, 558 num_arena_actions 559 }; 560 auto process = r.get() % 2; 561 auto body = [process] { 562 if (process) { 563 tbb::detail::d1::wait_context wctx{ 1 }; 564 tbb::task_group_context ctx; 565 tbb::this_task_arena::enqueue([&wctx] { wctx.release(); }); 566 tbb::detail::d1::wait(wctx, ctx); 567 } else { 568 global_actor(); 569 } 570 }; 571 switch (r.get() % (16*num_arena_actions)) { 572 case arena_execute: 573 if (entry.second > arenaLevel) { 574 gStats.notify(Statistics::ArenaExecute); 575 auto oldArenaLevel = arenaLevel; 576 arenaLevel = entry.second; 577 entry.first->execute(body); 578 arenaLevel = oldArenaLevel; 579 break; 580 } 581 utils_fallthrough; 582 case arena_enqueue: 583 utils_fallthrough; 584 default: 585 gStats.notify(Statistics::ArenaEnqueue); 586 entry.first->enqueue([] { global_actor(); }); 587 break; 588 } 589 arenaTable.release(lock); 590 } 591 } 592 }; 593 594 template <> 595 struct actor<parallel_algorithm> { 596 static void do_it(Random& rnd) { 597 enum PARTITIONERS { 598 simpl_part, 599 auto_part, 600 aff_part, 601 static_part, 602 num_parts 603 }; 604 int sz = rnd.get() % 10000; 605 auto doGlbAction = rnd.get() % 1000 == 42; 606 auto body = [doGlbAction, sz](int i) { 607 if (i == sz / 2 && doGlbAction) { 608 global_actor(); 609 } 610 }; 611 612 switch (rnd.get() % num_parts) { 613 case simpl_part: 614 tbb::parallel_for(0, sz, body, tbb::simple_partitioner{}); break; 615 case auto_part: 616 tbb::parallel_for(0, sz, body, tbb::auto_partitioner{}); break; 617 case aff_part: 618 { 619 tbb::affinity_partitioner aff; 620 tbb::parallel_for(0, sz, body, aff); break; 621 } 622 case static_part: 623 tbb::parallel_for(0, sz, body, tbb::static_partitioner{}); break; 624 } 625 } 626 }; 627 628 void global_actor() { 629 static thread_local std::uint64_t localNumActions{}; 630 631 while (globalNumActions < maxNumActions) { 632 auto& rnd = threadRandom; 633 switch (rnd.get() % num_actions) { 634 case arena_create: gStats.notify(Statistics::ArenaCreate); actor<arena_create>::do_it(rnd); break; 635 case arena_destroy: gStats.notify(Statistics::ArenaDestroy); actor<arena_destroy>::do_it(rnd); break; 636 case arena_action: gStats.notify(Statistics::ArenaAcquire); actor<arena_action>::do_it(rnd); break; 637 case parallel_algorithm: gStats.notify(Statistics::ParallelAlgorithm); actor<parallel_algorithm>::do_it(rnd); break; 638 } 639 640 if (++localNumActions == 100) { 641 localNumActions = 0; 642 globalNumActions += 100; 643 644 // TODO: Enable statistics 645 // static std::mutex mutex; 646 // std::lock_guard<std::mutex> lock{ mutex }; 647 // std::cout << globalNumActions << "\r" << std::flush; 648 } 649 } 650 } 651 652 #if TBB_USE_EXCEPTIONS 653 //! \brief \ref stress 654 TEST_CASE("Stress test with mixing functionality") { 655 // TODO add thread recreation 656 // TODO: Enable statistics 657 // tbb::task_scheduler_handle handle = tbb::task_scheduler_handle::get(); 658 659 const std::size_t numExtraThreads = 16; 660 utils::SpinBarrier startBarrier{numExtraThreads}; 661 utils::NativeParallelFor(numExtraThreads, [&startBarrier](std::size_t) { 662 startBarrier.wait(); 663 global_actor(); 664 }); 665 666 arenaTable.shutdown(); 667 668 // tbb::finalize(handle, std::nothrow_t{}); 669 670 // gStats.report(); 671 } 672 #endif 673