/*
    Copyright (c) 2021-2022 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

#include "common/config.h"

#include <oneapi/tbb/task_arena.h>
#include <oneapi/tbb/concurrent_vector.h>
#include <oneapi/tbb/rw_mutex.h>
#include <oneapi/tbb/task_group.h>
#include <oneapi/tbb/parallel_for.h>

#include <oneapi/tbb/global_control.h>

#include "common/test.h"
#include "common/utils.h"
#include "common/utils_concurrency_limit.h"
#include "common/spin_barrier.h"

#include <stdlib.h> // C11/POSIX aligned_alloc
#include <random>


//! \file test_scheduler_mix.cpp
//! \brief Test for [scheduler.task_arena scheduler.task_scheduler_observer] specification

const std::uint64_t maxNumActions = 1 * 100 * 1000;
static std::atomic<std::uint64_t> globalNumActions{};

//using Random = utils::FastRandom<>;
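// Per-thread pseudo-random source: each thread lazily creates its own
// mt19937-backed State (so get() never contends between threads), and all
// states are registered in a concurrent_vector and freed by the destructor.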
class Random {
    struct State {
        std::random_device rd;
        std::mt19937 gen;
        std::uniform_int_distribution<> dist;

        State() : gen(rd()), dist(0, std::numeric_limits<unsigned short>::max()) {}

        int get() {
            return dist(gen);
        }
    };
    static thread_local State* mState;
    tbb::concurrent_vector<State*> mStateList;
public:
    ~Random() {
        for (auto s : mStateList) {
            delete s;
        }
    }

    int get() {
        auto& s = mState;
        if (!s) {
            s = new State;
            mStateList.push_back(s);
        }
        return s->get();
    }
};

thread_local Random::State* Random::mState = nullptr;

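// Minimal portable aligned-allocation helpers; they place task_arena objects
// in over-aligned storage so that PtrRWMutex (below) can keep its lock state
// in the spare low bits of the pointer.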
void* aligned_malloc(std::size_t alignment, std::size_t size) {
#if _WIN32
    return _aligned_malloc(size, alignment);
#elif __unix__ || __APPLE__
    void* ptr{};
    int res = posix_memalign(&ptr, alignment, size);
    CHECK(res == 0);
    return ptr;
#else
    return aligned_alloc(alignment, size);
#endif
}

void aligned_free(void* ptr) {
#if _WIN32
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}

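// PtrRWMutex packs a T* and a reader-writer lock into a single atomic word.
// Because T is allocated with at least Alignment alignment, the low bits of
// the pointer are free: the bits below LOCK_PENDING count readers, the
// LOCK_PENDING bit blocks new readers while a writer waits, and all lock
// bits set at once (LOCKED) means write-locked.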
template <typename T, std::size_t Alignment>
class PtrRWMutex {
    static const std::size_t maxThreads = (Alignment >> 1) - 1;
    static const std::uintptr_t READER_MASK = maxThreads;       // 7F..
    static const std::uintptr_t LOCKED = Alignment - 1;         // FF..
    static const std::uintptr_t LOCKED_MASK = LOCKED;           // FF..
    static const std::uintptr_t LOCK_PENDING = READER_MASK + 1; // 80..

    std::atomic<std::uintptr_t> mState;

    T* pointer() const {
        return reinterpret_cast<T*>(state() & ~LOCKED_MASK);
    }

    std::uintptr_t state() const {
        return mState.load(std::memory_order_relaxed);
    }

public:
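    //! RAII guard over PtrRWMutex: acquires a shared or exclusive lock and
    //! releases it on destruction unless ownership was dropped via clear().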
    class ScopedLock {
    public:
        constexpr ScopedLock() : mMutex(nullptr), mIsWriter(false) {}
        //! Acquire lock on given mutex.
        ScopedLock(PtrRWMutex& m, bool write = true) : mMutex(nullptr) {
            CHECK_FAST(write == true);
            acquire(m);
        }
        //! Release lock (if lock is held).
        ~ScopedLock() {
            if (mMutex) {
                release();
            }
        }
        //! No Copy
        ScopedLock(const ScopedLock&) = delete;
        ScopedLock& operator=(const ScopedLock&) = delete;

        //! Acquire lock on given mutex.
        void acquire(PtrRWMutex& m) {
            CHECK_FAST(mMutex == nullptr);
            mIsWriter = true;
            mMutex = &m;
            mMutex->lock();
        }

        //! Try acquire lock on given mutex.
        bool tryAcquire(PtrRWMutex& m, bool write = true) {
            bool succeed = write ? m.tryLock() : m.tryLockShared();
            if (succeed) {
                mMutex = &m;
                mIsWriter = write;
            }
            return succeed;
        }

        void clear() {
            CHECK_FAST(mMutex != nullptr);
            CHECK_FAST(mIsWriter);
            PtrRWMutex* m = mMutex;
            mMutex = nullptr;
            m->clear();
        }

        //! Release lock.
        void release() {
            CHECK_FAST(mMutex != nullptr);
            PtrRWMutex* m = mMutex;
            mMutex = nullptr;

            if (mIsWriter) {
                m->unlock();
            } else {
                m->unlockShared();
            }
        }
    protected:
        PtrRWMutex* mMutex{};
        bool mIsWriter{};
    };

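    //! Try to publish ptr into this (empty) slot; fails if the slot is
    //! already occupied or locked.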
    bool trySet(T* ptr) {
        auto p = reinterpret_cast<std::uintptr_t>(ptr);
        CHECK_FAST((p & (Alignment - 1)) == 0);
        if (!state()) {
            std::uintptr_t expected = 0;
            if (mState.compare_exchange_strong(expected, p)) {
                return true;
            }
        }
        return false;
    }

    void clear() {
        CHECK_FAST((state() & LOCKED_MASK) == LOCKED);
        mState.store(0, std::memory_order_relaxed);
    }

    bool tryLock() {
        auto v = state();
        if (v == 0) {
            return false;
        }
        CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
        if ((v & READER_MASK) == 0) {
            if (mState.compare_exchange_strong(v, v | LOCKED)) {
                return true;
            }
        }
        return false;
    }

    bool tryLockShared() {
        auto v = state();
        if (v == 0) {
            return false;
        }
        CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
        if ((v & LOCKED_MASK) != LOCKED && (v & LOCK_PENDING) == 0) {
            if (mState.compare_exchange_strong(v, v + 1)) {
                return true;
            }
        }
        return false;
    }

    void lock() {
        auto v = state();
        mState.compare_exchange_strong(v, v | LOCK_PENDING);
        while (!tryLock()) {
            utils::yield();
        }
    }

    void unlock() {
        auto v = state();
        CHECK_FAST((v & LOCKED_MASK) == LOCKED);
        mState.store(v & ~LOCKED, std::memory_order_release);
    }

    void unlockShared() {
        auto v = state();
        CHECK_FAST((v & LOCKED_MASK) != LOCKED);
        CHECK_FAST((v & READER_MASK) > 0);
        mState -= 1;
    }

    operator bool() const {
        return pointer() != 0;
    }

    T* get() {
        return pointer();
    }
};

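// Per-thread action counters. Each thread accumulates into its own StatType
// (registered in a concurrent_vector); report() sums them up at the end.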
class Statistics {
public:
    enum ACTION {
        ArenaCreate,
        ArenaDestroy,
        ArenaAcquire,
        skippedArenaCreate,
        skippedArenaDestroy,
        skippedArenaAcquire,
        ParallelAlgorithm,
        ArenaEnqueue,
        ArenaExecute,
        numActions
    };

    static const char* const mStatNames[numActions];
private:
    struct StatType {
        StatType() : mCounters() {}
        std::array<std::uint64_t, numActions> mCounters;
    };

    tbb::concurrent_vector<StatType*> mStatsList;
    static thread_local StatType* mStats;

    StatType& get() {
        auto& s = mStats;
        if (!s) {
            s = new StatType;
            mStatsList.push_back(s);
        }
        return *s;
    }
public:
    ~Statistics() {
        for (auto s : mStatsList) {
            delete s;
        }
    }

    void notify(ACTION a) {
        ++get().mCounters[a];
    }

    void report() {
        StatType summary;
        for (auto& s : mStatsList) {
            for (int i = 0; i < numActions; ++i) {
                summary.mCounters[i] += s->mCounters[i];
            }
        }
        std::cout << std::endl << "Statistics:" << std::endl;
        std::cout << "Total actions: " << globalNumActions << std::endl;
        for (int i = 0; i < numActions; ++i) {
            std::cout << mStatNames[i] << ": " << summary.mCounters[i] << std::endl;
        }
    }
};


const char* const Statistics::mStatNames[Statistics::numActions] = {
    "Arena create", "Arena destroy", "Arena acquire",
    "Skipped arena create", "Skipped arena destroy", "Skipped arena acquire",
    "Parallel algorithm", "Arena enqueue", "Arena execute"
};
thread_local Statistics::StatType* Statistics::mStats;

static Statistics gStats;

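// LifetimeTracker is a small reference counter with a shutdown flag packed
// into the same atomic word. A Guard registers an in-flight operation unless
// shutdown has already been signalled; waitCompletion() spins until every
// guard has been released.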
class LifetimeTracker {
public:
    LifetimeTracker() = default;

    class Guard {
    public:
        Guard(LifetimeTracker* obj) {
            if (!(obj->mReferences.load(std::memory_order_relaxed) & SHUTDOWN_FLAG)) {
                if (obj->mReferences.fetch_add(REFERENCE_FLAG) & SHUTDOWN_FLAG) {
                    obj->mReferences.fetch_sub(REFERENCE_FLAG);
                } else {
                    mObj = obj;
                }
            }
        }

        Guard(Guard&& ing) : mObj(ing.mObj) {
            ing.mObj = nullptr;
        }

        ~Guard() {
            if (mObj != nullptr) {
                mObj->mReferences.fetch_sub(REFERENCE_FLAG);
            }
        }

        bool continue_execution() {
            return mObj != nullptr;
        }

    private:
        LifetimeTracker* mObj{nullptr};
    };

    Guard makeGuard() {
        return Guard(this);
    }

    void signalShutdown() {
        mReferences.fetch_add(SHUTDOWN_FLAG);
    }

    void waitCompletion() {
        utils::SpinWaitUntilEq(mReferences, SHUTDOWN_FLAG);
    }

private:
    friend class Guard;
    static constexpr std::uintptr_t SHUTDOWN_FLAG = 1;
    static constexpr std::uintptr_t REFERENCE_FLAG = 1 << 1;
    std::atomic<std::uintptr_t> mReferences{};
};

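// ArenaTable: a fixed table of up to 64 task_arena slots, each guarded by a
// PtrRWMutex. Arenas are constructed in storage aligned to arenaAlignment so
// the slot's pointer word has spare low bits for the lock state. The
// per-thread ThreadState records which slots this thread already holds,
// preventing recursive acquisition of the same slot.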
class ArenaTable {
    static const std::size_t maxArenas = 64;
    static const std::size_t maxThreads = 1 << 9;
    static const std::size_t arenaAlignment = maxThreads << 1;

    using ArenaPtrRWMutex = PtrRWMutex<tbb::task_arena, arenaAlignment>;
    std::array<ArenaPtrRWMutex, maxArenas> mArenaTable;

    struct ThreadState {
        bool lockedArenas[maxArenas]{};
        int arenaIdxStack[maxArenas];
        int level{};
    };

    LifetimeTracker mLifetimeTracker{};
    static thread_local ThreadState mThreadState;

    // Apply f to each slot, starting at `start` and wrapping around, until f reports success.
    template <typename F>
    auto find_arena(std::size_t start, F f) -> decltype(f(std::declval<ArenaPtrRWMutex&>(), std::size_t{})) {
        for (std::size_t idx = start, i = 0; i < maxArenas; ++i, idx = (idx + 1) % maxArenas) {
            auto res = f(mArenaTable[idx], idx);
            if (res) {
                return res;
            }
        }
        return {};
    }

public:
    using ScopedLock = ArenaPtrRWMutex::ScopedLock;

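    // Construct a randomly configured arena (threads, reserved slots, priority)
    // in over-aligned storage and publish it into a free slot; if every slot is
    // taken, count the attempt as skipped and destroy the arena again.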
    void create(Random& rnd) {
        auto guard = mLifetimeTracker.makeGuard();
        if (guard.continue_execution()) {
            int num_threads = rnd.get() % utils::get_platform_max_threads() + 1;
            unsigned int num_reserved = rnd.get() % num_threads;
            tbb::task_arena::priority priorities[] = { tbb::task_arena::priority::low, tbb::task_arena::priority::normal, tbb::task_arena::priority::high };
            tbb::task_arena::priority priority = priorities[rnd.get() % 3];

            tbb::task_arena* a = new (aligned_malloc(arenaAlignment, arenaAlignment)) tbb::task_arena{ num_threads, num_reserved, priority };

            if (!find_arena(rnd.get() % maxArenas, [a](ArenaPtrRWMutex& arena, std::size_t) -> bool {
                    if (arena.trySet(a)) {
                        return true;
                    }
                    return false;
                }))
            {
                gStats.notify(Statistics::skippedArenaCreate);
                a->~task_arena();
                aligned_free(a);
            }
        }
    }

    // Starting from a random slot, find one this thread does not hold, take the
    // write lock, detach the arena from the table, and destroy it; otherwise the
    // attempt is counted as skipped.
    void destroy(Random& rnd) {
        auto guard = mLifetimeTracker.makeGuard();
        if (guard.continue_execution()) {
            auto& ts = mThreadState;
            if (!find_arena(rnd.get() % maxArenas, [&ts](ArenaPtrRWMutex& arena, std::size_t idx) {
                    if (!ts.lockedArenas[idx]) {
                        ScopedLock lock;
                        if (lock.tryAcquire(arena, true)) {
                            auto a = arena.get();
                            lock.clear();
                            a->~task_arena();
                            aligned_free(a);
                            return true;
                        }
                    }
                    return false;
                }))
            {
                gStats.notify(Statistics::skippedArenaDestroy);
            }
        }
    }

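    // Shutdown: stop new table operations, wait for the ones already in flight,
    // then destroy whatever arenas are still published.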
    void shutdown() {
        mLifetimeTracker.signalShutdown();
        mLifetimeTracker.waitCompletion();
        find_arena(0, [](ArenaPtrRWMutex& arena, std::size_t) {
            if (arena.get()) {
                ScopedLock lock{ arena, true };
                auto a = arena.get();
                lock.clear();
                a->~task_arena();
                aligned_free(a);
            }
            return false;
        });
    }

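    // Take a shared (read) lock on some arena this thread does not hold yet and
    // return it together with its slot index; release() undoes the bookkeeping
    // in LIFO order.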
    std::pair<tbb::task_arena*, std::size_t> acquire(Random& rnd, ScopedLock& lock) {
        auto guard = mLifetimeTracker.makeGuard();

        tbb::task_arena* a{nullptr};
        std::size_t resIdx{};
        if (guard.continue_execution()) {
            auto& ts = mThreadState;
            a = find_arena(rnd.get() % maxArenas,
                [&ts, &lock, &resIdx](ArenaPtrRWMutex& arena, std::size_t idx) -> tbb::task_arena* {
                    if (!ts.lockedArenas[idx]) {
                        if (lock.tryAcquire(arena, false)) {
                            ts.lockedArenas[idx] = true;
                            ts.arenaIdxStack[ts.level++] = int(idx);
                            resIdx = idx;
                            return arena.get();
                        }
                    }
                    return nullptr;
                });
            if (!a) {
                gStats.notify(Statistics::skippedArenaAcquire);
            }
        }
        return { a, resIdx };
    }

    void release(ScopedLock& lock) {
        auto& ts = mThreadState;
        CHECK_FAST(ts.level > 0);
        auto idx = ts.arenaIdxStack[--ts.level];
        CHECK_FAST(ts.lockedArenas[idx]);
        ts.lockedArenas[idx] = false;
        lock.release();
    }
};

thread_local ArenaTable::ThreadState ArenaTable::mThreadState;

static ArenaTable arenaTable;
static Random threadRandom;

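// Top-level actions the stress test chooses between on every iteration of
// global_actor(); the commented-out entries are candidates for future extension.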
enum ACTIONS {
    arena_create,
    arena_destroy,
    arena_action,
    parallel_algorithm,
    // TODO:
    // observer_attach,
    // observer_detach,
    // flow_graph,
    // task_group,
    // resumable_tasks,

    num_actions
};

void global_actor();

template <ACTIONS action>
struct actor;

template <>
struct actor<arena_create> {
    static void do_it(Random& r) {
        arenaTable.create(r);
    }
};

template <>
struct actor<arena_destroy> {
    static void do_it(Random& r) {
        arenaTable.destroy(r);
    }
};

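// Acquire a random arena under a shared lock and either execute() into it or
// enqueue() more work. execute() is only used when moving to a strictly higher
// slot index than the arena currently being executed, which keeps the chain of
// nested execute() calls acyclic; everything else falls back to enqueue().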
template <>
struct actor<arena_action> {
    static void do_it(Random& r) {
        static thread_local std::size_t arenaLevel = 0;
        ArenaTable::ScopedLock lock;
        auto entry = arenaTable.acquire(r, lock);
        if (entry.first) {
            enum arena_actions {
                arena_execute,
                arena_enqueue,
                num_arena_actions
            };
            auto process = r.get() % 2;
            auto body = [process] {
                if (process) {
                    tbb::detail::d1::wait_context wctx{ 1 };
                    tbb::task_group_context ctx;
                    tbb::this_task_arena::enqueue([&wctx] { wctx.release(); });
                    tbb::detail::d1::wait(wctx, ctx);
                } else {
                    global_actor();
                }
            };
            switch (r.get() % (16 * num_arena_actions)) {
            case arena_execute:
                if (entry.second > arenaLevel) {
                    gStats.notify(Statistics::ArenaExecute);
                    auto oldArenaLevel = arenaLevel;
                    arenaLevel = entry.second;
                    entry.first->execute(body);
                    arenaLevel = oldArenaLevel;
                    break;
                }
                utils_fallthrough;
            case arena_enqueue:
                utils_fallthrough;
            default:
                gStats.notify(Statistics::ArenaEnqueue);
                entry.first->enqueue([] { global_actor(); });
                break;
            }
            arenaTable.release(lock);
        }
    }
};

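// Run a parallel_for over a random range with a randomly chosen partitioner;
// roughly one invocation in a thousand recurses into global_actor() from the
// middle of the loop body.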
template <>
struct actor<parallel_algorithm> {
    static void do_it(Random& rnd) {
        enum PARTITIONERS {
            simpl_part,
            auto_part,
            aff_part,
            static_part,
            num_parts
        };
        int sz = rnd.get() % 10000;
        auto doGlbAction = rnd.get() % 1000 == 42;
        auto body = [doGlbAction, sz](int i) {
            if (i == sz / 2 && doGlbAction) {
                global_actor();
            }
        };

        switch (rnd.get() % num_parts) {
        case simpl_part:
            tbb::parallel_for(0, sz, body, tbb::simple_partitioner{}); break;
        case auto_part:
            tbb::parallel_for(0, sz, body, tbb::auto_partitioner{}); break;
        case aff_part:
        {
            tbb::affinity_partitioner aff;
            tbb::parallel_for(0, sz, body, aff); break;
        }
        case static_part:
            tbb::parallel_for(0, sz, body, tbb::static_partitioner{}); break;
        }
    }
};

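// Each worker keeps drawing random actions until the shared action budget
// (maxNumActions) is exhausted; thread-local counts are flushed into
// globalNumActions in batches of 100 to limit contention.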
void global_actor() {
    static thread_local std::uint64_t localNumActions{};

    while (globalNumActions < maxNumActions) {
        auto& rnd = threadRandom;
        switch (rnd.get() % num_actions) {
        case arena_create: gStats.notify(Statistics::ArenaCreate); actor<arena_create>::do_it(rnd); break;
        case arena_destroy: gStats.notify(Statistics::ArenaDestroy); actor<arena_destroy>::do_it(rnd); break;
        case arena_action: gStats.notify(Statistics::ArenaAcquire); actor<arena_action>::do_it(rnd); break;
        case parallel_algorithm: gStats.notify(Statistics::ParallelAlgorithm); actor<parallel_algorithm>::do_it(rnd); break;
        }

        if (++localNumActions == 100) {
            localNumActions = 0;
            globalNumActions += 100;

            // TODO: Enable statistics
            // static std::mutex mutex;
            // std::lock_guard<std::mutex> lock{ mutex };
            // std::cout << globalNumActions << "\r" << std::flush;
        }
    }
}

#if TBB_USE_EXCEPTIONS
//! \brief \ref stress
TEST_CASE("Stress test with mixing functionality") {
    // TODO add thread recreation
    // TODO: Enable statistics
    tbb::task_scheduler_handle handle{ tbb::attach{} };

    const std::size_t numExtraThreads = 16;
    utils::SpinBarrier startBarrier{numExtraThreads};
    utils::NativeParallelFor(numExtraThreads, [&startBarrier](std::size_t) {
        startBarrier.wait();
        global_actor();
    });

    arenaTable.shutdown();

    tbb::finalize(handle);

    // gStats.report();
}
#endif