xref: /oneTBB/test/common/concurrency_tracker.h (revision c4568449)
1 /*
2     Copyright (c) 2005-2023 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #ifndef __TBB_test_common_concurrency_tracker_H
18 #define __TBB_test_common_concurrency_tracker_H
19 
20 #include "common/test.h"
21 #include "utils.h"
22 #include "spin_barrier.h"
23 #include "oneapi/tbb/parallel_for.h"
24 
25 #include <mutex>
26 
27 namespace utils {
28 
// Counters shared by every ConcurrencyTracker.
// NOTE(review): `static` at namespace scope in a header gives each including
// translation unit its own private copy of these variables.

// Number of threads currently inside an outermost ConcurrencyTracker scope.
static std::atomic<unsigned>        ctInstantParallelism{0u};
// Largest value ctInstantParallelism has reached since the last Reset().
static std::atomic<unsigned>        ctPeakParallelism{0u};
// Per-thread flag: 1 while this thread owns an outer tracker, 0 otherwise.
static thread_local std::uintptr_t  ctNested{0u};
32 
33 class ConcurrencyTracker {
34     bool    m_Outer;
35 
Started()36     static void Started () {
37         unsigned p = ++ctInstantParallelism;
38         unsigned q = ctPeakParallelism;
39         while( q<p ) {
40             ctPeakParallelism.compare_exchange_strong(q, p);
41         }
42     }
43 
Stopped()44     static void Stopped () {
45         //CHECK_MESSAGE ( ctInstantParallelism > 0, "Mismatched call to ConcurrencyTracker::Stopped()" );
46         --ctInstantParallelism;
47     }
48 public:
ConcurrencyTracker()49     ConcurrencyTracker() : m_Outer(false) {
50         std::uintptr_t nested = ctNested;
51         CHECK_FAST(nested <= 1);
52         if ( !ctNested ) {
53             Started();
54             m_Outer = true;
55             ctNested = 1;
56         }
57     }
~ConcurrencyTracker()58     ~ConcurrencyTracker() {
59         if ( m_Outer ) {
60             Stopped();
61 #if __GNUC__
62             // Some GCC versions tries to optimize out this write operation. So we need to perform this cast.
63             static_cast<volatile std::uintptr_t&>(ctNested) = 0;
64 #else
65             ctNested = 0;
66 #endif
67         }
68     }
69 
PeakParallelism()70     static unsigned PeakParallelism() { return ctPeakParallelism; }
InstantParallelism()71     static unsigned InstantParallelism() { return ctInstantParallelism; }
72 
Reset()73     static void Reset() {
74         CHECK_MESSAGE(ctInstantParallelism == 0, "Reset cannot be called when concurrency tracking is underway");
75         ctInstantParallelism = ctPeakParallelism = 0;
76     }
77 }; // ConcurrencyTracker
78 
79 
80 struct ExactConcurrencyLevel : NoCopy {
81 private:
82     SpinBarrier                 *myBarrier;
83 
84     // count unique worker threads
85     mutable std::atomic<int>    myUniqueThreadsCnt;
86     static thread_local int     myUniqueThreads;
87     static std::atomic<int>     myEpoch;
88 
89     mutable std::atomic<size_t> myActiveBodyCnt;
90     // output parameter for parallel_for body to report that max is reached
91     mutable std::atomic<bool>   myReachedMax;
92     // zero timeout means no barrier is used during concurrency level detection
93     const double                myTimeout;
94     const size_t                myConcLevel;
95 
96     static std::mutex global_mutex;
97 
ExactConcurrencyLevelExactConcurrencyLevel98     ExactConcurrencyLevel(double timeout, size_t concLevel) :
99         myBarrier(nullptr),
100         myUniqueThreadsCnt(0), myReachedMax(false),
101         myTimeout(timeout), myConcLevel(concLevel) {
102         myActiveBodyCnt = 0;
103     }
runExactConcurrencyLevel104     bool run() {
105         const int LOOP_ITERS = 100;
106         SpinBarrier barrier((unsigned)myConcLevel, /*throwaway=*/true);
107         if (myTimeout != 0.)
108             myBarrier = &barrier;
109         tbb::parallel_for((size_t)0, myConcLevel*LOOP_ITERS, *this, tbb::simple_partitioner());
110         return myReachedMax;
111     }
112 public:
operatorExactConcurrencyLevel113     void operator()(size_t) const {
114         size_t v = ++myActiveBodyCnt;
115         REQUIRE_MESSAGE(v <= myConcLevel, "Number of active bodies is too high.");
116         if (v == myConcLevel) // record that the max expected concurrency was observed
117             myReachedMax = true;
118         // try to get barrier when 1st time in the thread
119         if (myBarrier) {
120             myBarrier->wait();
121         }
122 
123         if (myUniqueThreads != myEpoch) {
124             ++myUniqueThreadsCnt;
125             myUniqueThreads = myEpoch;
126         }
127         for (int i=0; i<100; i++)
128             tbb::detail::machine_pause(1);
129         --myActiveBodyCnt;
130     }
131 
132     enum Mode {
133         None,
134         // When multiple blocking checks are performed, there might be not enough
135         // concurrency for all of them. Serialize check() calls.
136         Serialize
137     };
138 
139     // check that we have never got more than concLevel threads,
140     // and that in some moment we saw exactly concLevel threads
141     static void check(size_t concLevel, Mode m = None) {
142         ExactConcurrencyLevel o(30., concLevel);
143 
144         bool ok = false;
145         if (m == Serialize) {
146             std::lock_guard<std::mutex> lock(global_mutex);
147             ok = o.run();
148         } else {
149             ok = o.run();
150         }
151         CHECK(ok);
152     }
153 
isEqualExactConcurrencyLevel154     static bool isEqual(size_t concLevel) {
155         ExactConcurrencyLevel o(3., concLevel);
156         return o.run();
157     }
158 
checkLessOrEqualExactConcurrencyLevel159     static void checkLessOrEqual(size_t concLevel) {
160         ++ExactConcurrencyLevel::myEpoch;
161         ExactConcurrencyLevel o(0., concLevel);
162 
163         o.run(); // ignore result, as without a barrier it is not reliable
164         CHECK_MESSAGE(o.myUniqueThreadsCnt<=concLevel, "Too many workers observed.");
165     }
166 };
167 
168 std::mutex ExactConcurrencyLevel::global_mutex;
169 thread_local int ExactConcurrencyLevel::myUniqueThreads;
170 std::atomic<int> ExactConcurrencyLevel::myEpoch;
171 
} // namespace utils
173 
174 #endif /* __TBB_test_common_concurrency_tracker_H */
175