/*
    Copyright (c) 2021-2022 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

#include "common/config.h"

#include <oneapi/tbb/task_arena.h>
#include <oneapi/tbb/concurrent_vector.h>
#include <oneapi/tbb/rw_mutex.h>
#include <oneapi/tbb/task_group.h>
#include <oneapi/tbb/parallel_for.h>

#include <oneapi/tbb/global_control.h>

#include "common/test.h"
#include "common/utils.h"
#include "common/utils_concurrency_limit.h"
#include "common/spin_barrier.h"

#include <stdlib.h> // C11/POSIX aligned_alloc
// Standard headers for symbols used directly below (may also arrive transitively).
#include <array>
#include <atomic>
#include <cstdint>
#include <iostream>
#include <limits>
#include <random>
#include <utility>


//! \file test_scheduler_mix.cpp
//! \brief Test for [scheduler.task_arena scheduler.task_scheduler_observer] specification

const std::uint64_t maxNumActions = 1 * 100 * 1000;
static std::atomic<std::uint64_t> globalNumActions{};

//using Random = utils::FastRandom<>;
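// Thread-safe source of pseudo-random numbers: each thread lazily creates its
// own std::mt19937-based state, and every state is registered in a concurrent
// vector so the destructor can reclaim them all.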
class Random {
    struct State {
        std::random_device rd;
        std::mt19937 gen;
        std::uniform_int_distribution<> dist;

        State() : gen(rd()), dist(0, std::numeric_limits<unsigned short>::max()) {}

        int get() {
            return dist(gen);
        }
    };
    static thread_local State* mState;
    tbb::concurrent_vector<State*> mStateList;
public:
    ~Random() {
        for (auto s : mStateList) {
            delete s;
        }
    }

    int get() {
        auto& s = mState;
        if (!s) {
            s = new State;
            mStateList.push_back(s);
        }
        return s->get();
    }
};

thread_local Random::State* Random::mState = nullptr;


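// Portable aligned allocation: _aligned_malloc on Windows, posix_memalign on
// Unix-like systems, C11 aligned_alloc elsewhere; aligned_free dispatches to
// the matching deallocator.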
void* aligned_malloc(std::size_t alignment, std::size_t size) {
#if _WIN32
    return _aligned_malloc(size, alignment);
#elif __unix__ || __APPLE__
    void* ptr{};
    int res = posix_memalign(&ptr, alignment, size);
    CHECK(res == 0);
    return ptr;
#else
    return aligned_alloc(alignment, size);
#endif
}

void aligned_free(void* ptr) {
#if _WIN32
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}

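// Reader-writer mutex packed into a single word together with the pointer it
// guards. Because the pointee is aligned to `Alignment`, the pointer's low
// bits are free: they hold the reader count, a writer-pending flag, and the
// all-ones write-locked state, so the pointer and its lock change atomically.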
template <typename T, std::size_t Alignment>
class PtrRWMutex {
    static const std::size_t maxThreads = (Alignment >> 1) - 1;
    static const std::uintptr_t READER_MASK = maxThreads;       // 7F..
    static const std::uintptr_t LOCKED = Alignment - 1;         // FF..
    static const std::uintptr_t LOCKED_MASK = LOCKED;           // FF..
    static const std::uintptr_t LOCK_PENDING = READER_MASK + 1; // 80..

    std::atomic<std::uintptr_t> mState;

    // Const-qualified so that const members (e.g. operator bool) can call them.
    T* pointer() const {
        return reinterpret_cast<T*>(state() & ~LOCKED_MASK);
    }

    std::uintptr_t state() const {
        return mState.load(std::memory_order_relaxed);
    }

public:
    class ScopedLock {
    public:
        constexpr ScopedLock() : mMutex(nullptr), mIsWriter(false) {}
        //! Acquire lock on given mutex.
        ScopedLock(PtrRWMutex& m, bool write = true) : mMutex(nullptr) {
            CHECK_FAST(write == true);
            acquire(m);
        }
        //! Release lock (if lock is held).
        ~ScopedLock() {
            if (mMutex) {
                release();
            }
        }
        //! No Copy
        ScopedLock(const ScopedLock&) = delete;
        ScopedLock& operator=(const ScopedLock&) = delete;

        //! Acquire lock on given mutex.
        void acquire(PtrRWMutex& m) {
            CHECK_FAST(mMutex == nullptr);
            mIsWriter = true;
            mMutex = &m;
            mMutex->lock();
        }

        //! Try to acquire lock on given mutex.
        bool tryAcquire(PtrRWMutex& m, bool write = true) {
            bool succeed = write ? m.tryLock() : m.tryLockShared();
            if (succeed) {
                mMutex = &m;
                mIsWriter = write;
            }
            return succeed;
        }

        //! Detach from the mutex and reset the slot to empty without unlocking
        //! (used when the protected object is being destroyed).
        void clear() {
            CHECK_FAST(mMutex != nullptr);
            CHECK_FAST(mIsWriter);
            PtrRWMutex* m = mMutex;
            mMutex = nullptr;
            m->clear();
        }

        //! Release lock.
        void release() {
            CHECK_FAST(mMutex != nullptr);
            PtrRWMutex* m = mMutex;
            mMutex = nullptr;

            if (mIsWriter) {
                m->unlock();
            }
            else {
                m->unlockShared();
            }
        }
    protected:
        PtrRWMutex* mMutex{};
        bool mIsWriter{};
    };

    bool trySet(T* ptr) {
        auto p = reinterpret_cast<std::uintptr_t>(ptr);
        CHECK_FAST((p & (Alignment - 1)) == 0);
        if (!state()) {
            std::uintptr_t expected = 0;
            if (mState.compare_exchange_strong(expected, p)) {
                return true;
            }
        }
        return false;
    }

    void clear() {
        CHECK_FAST((state() & LOCKED_MASK) == LOCKED);
        mState.store(0, std::memory_order_relaxed);
    }

    bool tryLock() {
        auto v = state();
        if (v == 0) {
            return false;
        }
        CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
        if ((v & READER_MASK) == 0) {
            // No readers: try to set all low bits at once, write-locking the slot.
            if (mState.compare_exchange_strong(v, v | LOCKED)) {
                return true;
            }
        }
        return false;
    }

    bool tryLockShared() {
        auto v = state();
        if (v == 0) {
            return false;
        }
        CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
        if ((v & LOCKED_MASK) != LOCKED && (v & LOCK_PENDING) == 0) {
            // Not write-locked and no writer waiting: increment the reader count.
            if (mState.compare_exchange_strong(v, v + 1)) {
                return true;
            }
        }
        return false;
    }

    void lock() {
        auto v = state();
        // Announce a pending writer so new readers back off, then spin.
        mState.compare_exchange_strong(v, v | LOCK_PENDING);
        while (!tryLock()) {
            utils::yield();
        }
    }

    void unlock() {
        auto v = state();
        CHECK_FAST((v & LOCKED_MASK) == LOCKED);
        mState.store(v & ~LOCKED, std::memory_order_release);
    }

    void unlockShared() {
        auto v = state();
        CHECK_FAST((v & LOCKED_MASK) != LOCKED);
        CHECK_FAST((v & READER_MASK) > 0);
        mState -= 1;
    }

    operator bool() const {
        return pointer() != nullptr;
    }

    T* get() {
        return pointer();
    }
};

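// Per-thread action counters merged into a single summary on report(). Each
// thread lazily registers its own counter block in a concurrent vector, so
// notify() never touches another thread's counters.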
class Statistics {
public:
    enum ACTION {
        ArenaCreate,
        ArenaDestroy,
        ArenaAcquire,
        skippedArenaCreate,
        skippedArenaDestroy,
        skippedArenaAcquire,
        ParallelAlgorithm,
        ArenaEnqueue,
        ArenaExecute,
        numActions
    };

    static const char* const mStatNames[numActions];
private:
    struct StatType {
        StatType() : mCounters() {}
        std::array<std::uint64_t, numActions> mCounters;
    };

    tbb::concurrent_vector<StatType*> mStatsList;
    static thread_local StatType* mStats;

    StatType& get() {
        auto& s = mStats;
        if (!s) {
            s = new StatType;
            mStatsList.push_back(s);
        }
        return *s;
    }
public:
    ~Statistics() {
        for (auto s : mStatsList) {
            delete s;
        }
    }

    void notify(ACTION a) {
        ++get().mCounters[a];
    }

    void report() {
        StatType summary;
        for (auto& s : mStatsList) {
            for (int i = 0; i < numActions; ++i) {
                summary.mCounters[i] += s->mCounters[i];
            }
        }
        std::cout << std::endl << "Statistics:" << std::endl;
        std::cout << "Total actions: " << globalNumActions << std::endl;
        for (int i = 0; i < numActions; ++i) {
            std::cout << mStatNames[i] << ": " << summary.mCounters[i] << std::endl;
        }
    }
};


const char* const Statistics::mStatNames[Statistics::numActions] = {
    "Arena create", "Arena destroy", "Arena acquire",
    "Skipped arena create", "Skipped arena destroy", "Skipped arena acquire",
    "Parallel algorithm", "Arena enqueue", "Arena execute"
};
thread_local Statistics::StatType* Statistics::mStats;

static Statistics gStats;

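// Tracks how many threads are currently inside arena operations. makeGuard()
// takes a reference unless shutdown has been signalled; waitCompletion() spins
// until every outstanding guard is gone, so arenas can be torn down safely.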
class LifetimeTracker {
public:
    LifetimeTracker() = default;

    class Guard {
    public:
        Guard(LifetimeTracker* obj) {
            if (!(obj->mReferences.load(std::memory_order_relaxed) & SHUTDOWN_FLAG)) {
                if (obj->mReferences.fetch_add(REFERENCE_FLAG) & SHUTDOWN_FLAG) {
                    // Shutdown started between the check and the increment: back out.
                    obj->mReferences.fetch_sub(REFERENCE_FLAG);
                } else {
                    mObj = obj;
                }
            }
        }

        Guard(Guard&& ing) : mObj(ing.mObj) {
            ing.mObj = nullptr;
        }

        ~Guard() {
            if (mObj != nullptr) {
                mObj->mReferences.fetch_sub(REFERENCE_FLAG);
            }
        }

        bool continue_execution() {
            return mObj != nullptr;
        }

    private:
        LifetimeTracker* mObj{nullptr};
    };

    Guard makeGuard() {
        return Guard(this);
    }

    void signalShutdown() {
        mReferences.fetch_add(SHUTDOWN_FLAG);
    }

    void waitCompletion() {
        utils::SpinWaitUntilEq(mReferences, SHUTDOWN_FLAG);
    }

private:
    friend class Guard;
    static constexpr std::uintptr_t SHUTDOWN_FLAG = 1;
    static constexpr std::uintptr_t REFERENCE_FLAG = 1 << 1;
    std::atomic<std::uintptr_t> mReferences{};
};

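// Fixed table of arena slots. Arenas are over-aligned so each slot's
// PtrRWMutex can keep its lock state in the pointer's low bits: shared locks
// are taken to use an arena, a unique lock to destroy one, and the per-thread
// state prevents a thread from re-locking a slot it already holds.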
class ArenaTable {
    static const std::size_t maxArenas = 64;
    static const std::size_t maxThreads = 1 << 9;
    static const std::size_t arenaAlignment = maxThreads << 1;

    using ArenaPtrRWMutex = PtrRWMutex<tbb::task_arena, arenaAlignment>;
    std::array<ArenaPtrRWMutex, maxArenas> mArenaTable;

    struct ThreadState {
        bool lockedArenas[maxArenas]{};
        int arenaIdxStack[maxArenas];
        int level{};
    };

    LifetimeTracker mLifetimeTracker{};
    static thread_local ThreadState mThreadState;

    // Apply f to each slot, starting at `start` and wrapping around, until f
    // returns a truthy result; returns a default-constructed value otherwise.
    template <typename F>
    auto find_arena(std::size_t start, F f) -> decltype(f(std::declval<ArenaPtrRWMutex&>(), std::size_t{})) {
        for (std::size_t idx = start, i = 0; i < maxArenas; ++i, idx = (idx + 1) % maxArenas) {
            auto res = f(mArenaTable[idx], idx);
            if (res) {
                return res;
            }
        }
        return {};
    }

public:
    using ScopedLock = ArenaPtrRWMutex::ScopedLock;

    void create(Random& rnd) {
        auto guard = mLifetimeTracker.makeGuard();
        if (guard.continue_execution()) {
            int num_threads = rnd.get() % utils::get_platform_max_threads() + 1;
            unsigned int num_reserved = rnd.get() % num_threads;
            tbb::task_arena::priority priorities[] = { tbb::task_arena::priority::low, tbb::task_arena::priority::normal, tbb::task_arena::priority::high };
            tbb::task_arena::priority priority = priorities[rnd.get() % 3];

            tbb::task_arena* a = new (aligned_malloc(arenaAlignment, arenaAlignment)) tbb::task_arena{ num_threads, num_reserved, priority };

            // Publish the arena into a free slot; if the table is full, discard it.
            if (!find_arena(rnd.get() % maxArenas, [a](ArenaPtrRWMutex& arena, std::size_t) -> bool {
                    if (arena.trySet(a)) {
                        return true;
                    }
                    return false;
                }))
            {
                gStats.notify(Statistics::skippedArenaCreate);
                a->~task_arena();
                aligned_free(a);
            }
        }
    }

    void destroy(Random& rnd) {
        auto guard = mLifetimeTracker.makeGuard();
        if (guard.continue_execution()) {
            auto& ts = mThreadState;
            // A slot may be destroyed only under a unique lock, and only if this
            // thread does not already hold it in shared mode.
            if (!find_arena(rnd.get() % maxArenas, [&ts](ArenaPtrRWMutex& arena, std::size_t idx) {
                    if (!ts.lockedArenas[idx]) {
                        ScopedLock lock;
                        if (lock.tryAcquire(arena, true)) {
                            auto a = arena.get();
                            lock.clear();
                            a->~task_arena();
                            aligned_free(a);
                            return true;
                        }
                    }
                    return false;
                }))
            {
                gStats.notify(Statistics::skippedArenaDestroy);
            }
        }
    }

    void shutdown() {
        mLifetimeTracker.signalShutdown();
        mLifetimeTracker.waitCompletion();
        // No other thread can touch the table anymore; destroy the remaining arenas.
        find_arena(0, [](ArenaPtrRWMutex& arena, std::size_t) {
            if (arena.get()) {
                ScopedLock lock{ arena, true };
                auto a = arena.get();
                lock.clear();
                a->~task_arena();
                aligned_free(a);
            }
            return false;
        });
    }

    std::pair<tbb::task_arena*, std::size_t> acquire(Random& rnd, ScopedLock& lock) {
        auto guard = mLifetimeTracker.makeGuard();

        tbb::task_arena* a{nullptr};
        std::size_t resIdx{};
        if (guard.continue_execution()) {
            auto& ts = mThreadState;
            a = find_arena(rnd.get() % maxArenas,
                [&ts, &lock, &resIdx](ArenaPtrRWMutex& arena, std::size_t idx) -> tbb::task_arena* {
                    if (!ts.lockedArenas[idx]) {
                        if (lock.tryAcquire(arena, false)) {
                            ts.lockedArenas[idx] = true;
                            ts.arenaIdxStack[ts.level++] = int(idx);
                            resIdx = idx;
                            return arena.get();
                        }
                    }
                    return nullptr;
                });
            if (!a) {
                gStats.notify(Statistics::skippedArenaAcquire);
            }
        }
        return { a, resIdx };
    }

    void release(ScopedLock& lock) {
        auto& ts = mThreadState;
        CHECK_FAST(ts.level > 0);
        auto idx = ts.arenaIdxStack[--ts.level];
        CHECK_FAST(ts.lockedArenas[idx]);
        ts.lockedArenas[idx] = false;
        lock.release();
    }
};

thread_local ArenaTable::ThreadState ArenaTable::mThreadState;

static ArenaTable arenaTable;
static Random threadRandom;

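// Top-level actions the stress test randomly interleaves; each one has an
// actor<> specialization below.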
enum ACTIONS {
    arena_create,
    arena_destroy,
    arena_action,
    parallel_algorithm,
    // TODO:
    // observer_attach,
    // observer_detach,
    // flow_graph,
    // task_group,
    // resumable_tasks,

    num_actions
};

void global_actor();

template <ACTIONS action>
struct actor;

template <>
struct actor<arena_create> {
    static void do_it(Random& r) {
        arenaTable.create(r);
    }
};

template <>
struct actor<arena_destroy> {
    static void do_it(Random& r) {
        arenaTable.destroy(r);
    }
};

template <>
struct actor<arena_action> {
    static void do_it(Random& r) {
        static thread_local std::size_t arenaLevel = 0;
        ArenaTable::ScopedLock lock;
        auto entry = arenaTable.acquire(r, lock);
        if (entry.first) {
            enum arena_actions {
                arena_execute,
                arena_enqueue,
                num_arena_actions
            };
            auto process = r.get() % 2;
            auto body = [process] {
                if (process) {
                    // Use the internal wait API to block until an enqueued task runs.
                    tbb::detail::d1::wait_context wctx{ 1 };
                    tbb::task_group_context ctx;
                    tbb::this_task_arena::enqueue([&wctx] { wctx.release(); });
                    tbb::detail::d1::wait(wctx, ctx);
                } else {
                    global_actor();
                }
            };
            // Heavily biased towards enqueue; execute is attempted only when moving
            // to a higher slot index, which keeps nested execute calls acyclic.
            switch (r.get() % (16 * num_arena_actions)) {
            case arena_execute:
                if (entry.second > arenaLevel) {
                    gStats.notify(Statistics::ArenaExecute);
                    auto oldArenaLevel = arenaLevel;
                    arenaLevel = entry.second;
                    entry.first->execute(body);
                    arenaLevel = oldArenaLevel;
                    break;
                }
                utils_fallthrough;
            case arena_enqueue:
                utils_fallthrough;
            default:
                gStats.notify(Statistics::ArenaEnqueue);
                entry.first->enqueue([] { global_actor(); });
                break;
            }
            arenaTable.release(lock);
        }
    }
};

template <>
struct actor<parallel_algorithm> {
    static void do_it(Random& rnd) {
        enum PARTITIONERS {
            simpl_part,
            auto_part,
            aff_part,
            static_part,
            num_parts
        };
        int sz = rnd.get() % 10000;
        auto doGlbAction = rnd.get() % 1000 == 42;
        auto body = [doGlbAction, sz](int i) {
            if (i == sz / 2 && doGlbAction) {
                global_actor();
            }
        };

        switch (rnd.get() % num_parts) {
        case simpl_part:
            tbb::parallel_for(0, sz, body, tbb::simple_partitioner{}); break;
        case auto_part:
            tbb::parallel_for(0, sz, body, tbb::auto_partitioner{}); break;
        case aff_part:
        {
            tbb::affinity_partitioner aff;
            tbb::parallel_for(0, sz, body, aff); break;
        }
        case static_part:
            tbb::parallel_for(0, sz, body, tbb::static_partitioner{}); break;
        }
    }
};

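// Main driver executed by every test thread (and re-entered from arena tasks
// and parallel_for bodies): keeps picking random actions until the global
// action budget is exhausted.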
void global_actor() {
    static thread_local std::uint64_t localNumActions{};

    while (globalNumActions < maxNumActions) {
        auto& rnd = threadRandom;
        switch (rnd.get() % num_actions) {
        case arena_create:  gStats.notify(Statistics::ArenaCreate); actor<arena_create>::do_it(rnd);  break;
        case arena_destroy: gStats.notify(Statistics::ArenaDestroy); actor<arena_destroy>::do_it(rnd); break;
        case arena_action:  gStats.notify(Statistics::ArenaAcquire); actor<arena_action>::do_it(rnd);  break;
        case parallel_algorithm: gStats.notify(Statistics::ParallelAlgorithm); actor<parallel_algorithm>::do_it(rnd);  break;
        }

        // Flush the thread-local count into the shared atomic in batches of 100
        // to keep contention low.
        if (++localNumActions == 100) {
            localNumActions = 0;
            globalNumActions += 100;

            // TODO: Enable statistics
            // static std::mutex mutex;
            // std::lock_guard<std::mutex> lock{ mutex };
            // std::cout << globalNumActions << "\r" << std::flush;
        }
    }
}

#if TBB_USE_EXCEPTIONS
//! \brief \ref stress
TEST_CASE("Stress test with mixing functionality") {
    // TODO add thread recreation
    // TODO: Enable statistics
    tbb::task_scheduler_handle handle{ tbb::attach{} };

    const std::size_t numExtraThreads = 16;
    utils::SpinBarrier startBarrier{numExtraThreads};
    utils::NativeParallelFor(numExtraThreads, [&startBarrier](std::size_t) {
        startBarrier.wait();
        global_actor();
    });

    arenaTable.shutdown();

    tbb::finalize(handle);

    // gStats.report();
}
#endif