xref: /oneTBB/test/tbb/test_semaphore.cpp (revision b15aabb3)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 //! \file test_semaphore.cpp
18 //! \brief Test for [internal] functionality
19 
20 // Test for counting semaphore
21 #include "common/test.h"
22 #include "common/utils.h"
23 #include "common/spin_barrier.h"
24 #include "tbb/blocked_range.h"
25 #include "tbb/tick_count.h"
26 #include "../../src/tbb/semaphore.h"
27 #include <atomic>
28 #include <vector>
29 
30 using tbb::detail::r1::semaphore;
31 
32 std::atomic<int> pCount;
33 utils::SpinBarrier sBarrier;
34 
35 // Semaphore basis function:
36 //  set semaphore to initial value
37 // see that semaphore only allows that number of threads to be active
38 class Body : utils::NoAssign {
39     const int nIters;
40     semaphore& mySem;
41     std::vector<int>& ourCounts;
42     std::vector<double>& tottime;
43 
44     static constexpr int tickCounts = 1; // millisecond
45     static constexpr int innerWait = 5; // millisecond
46 public:
47     Body( int nThread, int nIter, semaphore& sem,
48           std::vector<int>& our_counts, std::vector<double>& tot_time )
49         : nIters(nIter), mySem(sem), ourCounts(our_counts), tottime(tot_time)
50     {
51         sBarrier.initialize(nThread);
52         pCount = 0;
53     }
54 
55     void operator()( const int tid ) const {
56         sBarrier.wait();
57 
58         for (int i = 0; i < nIters; ++i) {
59             utils::Sleep(tid * tickCounts);
60             tbb::tick_count t0 = tbb::tick_count::now();
61             mySem.P();
62             tbb::tick_count t1 = tbb::tick_count::now();
63             tottime[tid] += (t1 - t0).seconds();
64 
65             int curval = ++pCount;
66             if (curval > ourCounts[tid]) {
67                 ourCounts[tid] = curval;
68             }
69             utils::Sleep(innerWait);
70             --pCount;
71             REQUIRE(int(pCount) >= 0);
72             mySem.V();
73         }
74     }
75 }; // class Body
76 
77 void test_semaphore( int sem_init_cnt, int extra_threads ) {
78     semaphore my_sem(sem_init_cnt);
79     int n_threads = sem_init_cnt + extra_threads;
80 
81     std::vector<int> max_vals(n_threads);
82     std::vector<double> tot_times(n_threads);
83 
84     int n_iters = 10;
85     Body body(n_threads, n_iters, my_sem, max_vals, tot_times);
86 
87     pCount = 0;
88     utils::NativeParallelFor(n_threads, body);
89 
90     if (extra_threads == 0) {
91         double allPWaits = 0;
92         for (auto item : tot_times) {
93             allPWaits += item;
94         }
95         allPWaits /= static_cast<double>(n_threads * n_iters);
96     }
97     REQUIRE_MESSAGE(!pCount, "not all threads decremented pCount");
98 
99     int max_count = -1;
100     for (auto item : max_vals) {
101         max_count = utils::max(max_count, item);
102     }
103     REQUIRE_MESSAGE(max_count <= sem_init_cnt, "Too many threads in semaphore-protected increment");
104 }
105 
106 #include "../../src/tbb/semaphore.cpp"
107 #if _WIN32 || _WIN64
108 #include "../../src/tbb/dynamic_link.cpp"
109 #endif
110 
// Number of semaphore-protected increments each AddOne invocation performs.
constexpr std::size_t N_TIMES{1000};
112 
// Shared counter plus the semaphore of type S that guards it; value starts at zero.
template <typename S>
struct Counter {
    std::atomic<long> value{0};
    S my_sem;
}; // struct Counter
119 
120 // Function object for use with parallel_for.h
121 template <typename C>
122 struct AddOne : utils::NoAssign {
123     C& my_counter;
124 
125     // Increments counter once for each iteration in the iteration space
126     void operator()( int ) const {
127         for (std::size_t i = 0; i < N_TIMES; ++i) {
128             my_counter.my_sem.P();
129             ++my_counter.value;
130             my_counter.my_sem.V();
131         }
132     }
133 
134     AddOne( C& c ) : my_counter(c) {
135         my_counter.my_sem.V();
136     }
137 }; // struct AddOne
138 
139 void test_binary_semaphore( int n_threads ) {
140     Counter<tbb::detail::r1::binary_semaphore> counter;
141     AddOne<decltype(counter)> AddOneBody(counter);
142     utils::NativeParallelFor(n_threads, AddOneBody);
143     REQUIRE_MESSAGE(n_threads * N_TIMES == counter.value, "Binary semaphore operations P()/V() have a race");
144 }
145 
// Number of slots in the producer/consumer ring buffer, i.e. the most tokens that can be
// in flight. Slots are indexed with `index & (MAX_TOKENS - 1)`, which only equals
// `index % MAX_TOKENS` when MAX_TOKENS is a power of two.
constexpr std::size_t MAX_TOKENS = 32;
static_assert(MAX_TOKENS != 0 && (MAX_TOKENS & (MAX_TOKENS - 1)) == 0,
              "MAX_TOKENS must be a power of two: ring slots are selected with a bit mask");

// Role of a FilterBase in the two-stage pipeline.
enum FilterType { imaProducer, imaConsumer };
149 
150 class FilterBase : utils::NoAssign {
151 protected:
152     FilterType ima;
153     unsigned totTokens; // total number of tokens to be emitted, only used by producer
154     std::atomic<unsigned>& myTokens;
155     std::atomic<unsigned>& otherTokens;
156 
157     unsigned myWait;
158     semaphore& my_sem;
159     semaphore& next_sem;
160 
161     unsigned* myBuffer;
162     unsigned* nextBuffer;
163     unsigned curToken;
164 public:
165     FilterBase( FilterType filter,
166                 unsigned tot_tokens,
167                 std::atomic<unsigned>& my_tokens,
168                 std::atomic<unsigned>& other_tokens,
169                 unsigned my_wait,
170                 semaphore& m_sem,
171                 semaphore& n_sem,
172                 unsigned* buf,
173                 unsigned* n_buf )
174         : ima(filter), totTokens(tot_tokens), myTokens(my_tokens),
175           otherTokens(other_tokens), myWait(my_wait), my_sem(m_sem),
176           next_sem(n_sem), myBuffer(buf), nextBuffer(n_buf)
177     {
178         curToken = 0;
179     }
180 
181     void Produce( const int );
182     void Consume( const int );
183     void operator()( const int tid ) {
184         if (ima == imaConsumer) {
185             Consume(tid);
186         } else {
187             Produce(tid);
188         }
189     }
190 }; // class FilterBase
191 
192 class ProduceConsumeBody {
193     FilterBase** my_filters;
194 public:
195     ProduceConsumeBody( FilterBase** filters ) : my_filters(filters) {}
196 
197     void operator()( const int tid ) const {
198         my_filters[tid]->operator()(tid);
199     }
200 }; // class ProduceConsumeBody
201 
202 // send a bunch of non-null "tokens" to consumer, then a NULL
203 void FilterBase::Produce( const int ) {
204     nextBuffer[0] = 0; // just in case we provide no tokens
205     sBarrier.wait();
206     while(totTokens) {
207         while(!myTokens) {
208             my_sem.P();
209         }
210         // we have a slot available
211         --myTokens; // moving this down reduces spurious wakeups
212         --totTokens;
213         if (totTokens) {
214             nextBuffer[curToken & (MAX_TOKENS - 1)] = curToken * 3 + 1;
215         } else {
216             nextBuffer[curToken & (MAX_TOKENS - 1)] = 0;
217         }
218         ++curToken;
219 
220         utils::Sleep(myWait);
221         unsigned temp = ++otherTokens;
222         if (temp == 1) {
223             next_sem.V();
224         }
225     }
226     next_sem.V(); // final wakeup
227 }
228 
229 void FilterBase::Consume( const int ) {
230     unsigned myToken;
231     sBarrier.wait();
232     do {
233         while( !myTokens ) {
234             my_sem.P();
235         }
236         // we have a slot available
237         --myTokens;
238         myToken = myBuffer[curToken & (MAX_TOKENS - 1)];
239         if (myToken) {
240             REQUIRE_MESSAGE(myToken == curToken * 3 + 1, "Error in received token");
241             ++curToken;
242             utils::Sleep(myWait);
243             unsigned temp = ++otherTokens;
244             if (temp == 1) {
245                 next_sem.V();
246             }
247         }
248     } while(myToken);
249     // end of processing
250     REQUIRE_MESSAGE(curToken + 1 == totTokens, "Didn't receive enough tokens");
251 }
252 
253 // test of producer/consumer with atomic buffer cnt and semaphore
254 // nTokens are total number of tokens through the pipe
255 // pWait is the wait time for the producer
256 // cWait is the wait time for the consumer
257 void test_producer_consumer( unsigned totTokens, unsigned nTokens, unsigned pWait, unsigned cWait ) {
258     semaphore p_sem;
259     semaphore c_sem;
260     std::atomic<unsigned> p_tokens;
261     std::atomic<unsigned> c_tokens(0);
262 
263     unsigned c_buffer[MAX_TOKENS];
264     FilterBase* my_filters[2]; // one producer, one concumer
265 
266     REQUIRE_MESSAGE(nTokens <= MAX_TOKENS, "Not enough slots for tokens");
267 
268     my_filters[0] = new FilterBase(imaProducer, totTokens, p_tokens, c_tokens, pWait, c_sem, p_sem, nullptr, &(c_buffer[0]));
269     my_filters[1] = new FilterBase(imaConsumer, totTokens, c_tokens, p_tokens, cWait, p_sem, c_sem, c_buffer, nullptr);
270 
271     p_tokens = nTokens;
272     ProduceConsumeBody body(my_filters);
273     sBarrier.initialize(2);
274     utils::NativeParallelFor(2, body);
275     delete my_filters[0];
276     delete my_filters[1];
277 }
278 
279 //! \brief \ref error_guessing
280 TEST_CASE("test binary semaphore") {
281     test_binary_semaphore(utils::MaxThread);
282 }
283 
284 //! \brief \ref error_guessing
285 TEST_CASE("test semaphore") {
286     for (int sem_size = 1; sem_size <= int(utils::MaxThread); ++sem_size) {
287         for (int ex_threads = 0; ex_threads <= int(utils::MaxThread) - sem_size; ++ex_threads) {
288             test_semaphore(sem_size, ex_threads);
289         }
290     }
291 }
292 
293 //! \brief \ref error_guessing
294 TEST_CASE("test producer-consumer") {
295     test_producer_consumer(10, 2, 5, 5);
296     test_producer_consumer(10, 2, 20, 5);
297     test_producer_consumer(10, 2, 5, 20);
298 
299     test_producer_consumer(10, 1, 5, 5);
300     test_producer_consumer(20, 10, 5, 20);
301     test_producer_consumer(64, 32, 1, 20);
302 }
303