1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/test.h"
18 #include "common/utils.h"
19 
20 #include "oneapi/tbb/parallel_for.h"
21 #include "oneapi/tbb/spin_mutex.h"
22 #include "oneapi/tbb/spin_rw_mutex.h"
23 #include "oneapi/tbb/queuing_mutex.h"
24 #include "oneapi/tbb/queuing_rw_mutex.h"
25 #include "oneapi/tbb/null_mutex.h"
26 #include "oneapi/tbb/null_rw_mutex.h"
27 
28 #include <type_traits>
29 
30 //! \file conformance_mutex.cpp
31 //! \brief Test for [mutex.spin_mutex mutex.spin_rw_mutex mutex.queuing_mutex mutex.queuing_rw_mutex mutex.speculative_spin_mutex mutex.speculative_spin_rw_mutex mutex.null_mutex mutex.null_rw_mutex] specifications
32 
//! Shared counter guarded by a mutex of type M, used by the basic stress tests.
template<typename M>
struct Counter {
    //! Type of the guarding mutex, exposed for helpers that only know the counter type.
    typedef M mutex_type;
    //! Mutex that tests lock before touching `value`.
    M mutex;
    //! The counter itself; declared volatile (presumably so every increment is a real load/store).
    volatile long value;
};
39 
40 //! Generic test of a TBB mutex
41 /** Does not test features specific to reader-writer locks. */
42 template<typename M>
43 void GeneralTest(const char* mutex_name, bool check = true) { // check flag is needed to disable correctness check for null mutexes (for test reusage)
44     const int N = 100000;
45     const int GRAIN = 10000;
46     Counter<M> counter;
47     counter.value = 0;
48 
49     // Stress test to force possible race condition of the counter
50     utils::NativeParallelFor(N, GRAIN, [&] (int i) {
51         if (i & 1) {
52             // Try implicit acquire and explicit release
53             typename M::scoped_lock lock(counter.mutex);
54             counter.value = counter.value + 1;
55             lock.release();
56         } else {
57             // Try explicit acquire and implicit release
58             typename M::scoped_lock lock;
59             lock.acquire(counter.mutex);
60             counter.value = counter.value + 1;
61         }
62     });
63     if (check) {
64         REQUIRE_MESSAGE(counter.value == N, "ERROR for " << mutex_name << ": race is detected");
65     }
66 }
67 
68 //! Test try_acquire functionality of a non-reenterable mutex
69 template<typename M>
70 void TestTryAcquire(const char* mutex_name) {
71     M tested_mutex;
72     typename M::scoped_lock lock_outer;
73     if (lock_outer.try_acquire(tested_mutex)) {
74         lock_outer.release();
75     } else {
76         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
77     }
78     {
79         typename M::scoped_lock lock_inner(tested_mutex);
80         CHECK_MESSAGE(!lock_outer.try_acquire(tested_mutex), "ERROR for " << mutex_name << ": try_acquire failed though it should not (1)");
81     }
82     if (lock_outer.try_acquire(tested_mutex)) {
83         lock_outer.release();
84     } else {
85         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
86     }
87 }
88 
89 template <>
90 void TestTryAcquire<oneapi::tbb::null_mutex>( const char* mutex_name ) {
91     oneapi::tbb::null_mutex tested_mutex;
92     typename oneapi::tbb::null_mutex::scoped_lock lock(tested_mutex);
93     CHECK_MESSAGE(lock.try_acquire(tested_mutex), "ERROR for " << mutex_name << ": try_acquire failed though it should not");
94     lock.release();
95     CHECK_MESSAGE(lock.try_acquire(tested_mutex), "ERROR for " << mutex_name << ": try_acquire failed though it should not");
96 }
97 
98 //! Test try_acquire functionality of a non-reenterable mutex
99 template<typename M>
100 void TestTryAcquireReader(const char* mutex_name) {
101     M tested_mutex;
102     typename M::scoped_lock lock_outer;
103     if (lock_outer.try_acquire(tested_mutex, false) ) {
104         lock_outer.release();
105     } else {
106         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
107     }
108     {
109         typename M::scoped_lock lock_inner(tested_mutex, false); // read lock
110         // try acquire on write
111         CHECK_MESSAGE(!lock_outer.try_acquire(tested_mutex, true), "ERROR for " << mutex_name << ": try_acquire on write succeed though it should not (1)");
112         lock_inner.release();                                    // unlock
113         lock_inner.acquire(tested_mutex, true);                  // write lock
114         // try acquire on read
115         CHECK_MESSAGE(!lock_outer.try_acquire(tested_mutex, false), "ERROR for " << mutex_name << ": try_acquire on read succeed though it should not (2)");
116     }
117     if (lock_outer.try_acquire(tested_mutex, false) ) {
118         lock_outer.release();
119     } else {
120         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
121     }
122 }
123 
124 template <>
125 void TestTryAcquireReader<oneapi::tbb::null_rw_mutex>( const char* mutex_name ) {
126     oneapi::tbb::null_rw_mutex tested_mutex;
127     typename oneapi::tbb::null_rw_mutex::scoped_lock lock(tested_mutex, false);
128     CHECK_MESSAGE(lock.try_acquire(tested_mutex, false), "Error for " << mutex_name << ": try_acquire on read failed though it should not");
129     CHECK_MESSAGE(lock.try_acquire(tested_mutex, true), "Error for " << mutex_name << ": try_acquire on write failed though it should not");
130     lock.release();
131     CHECK_MESSAGE(lock.try_acquire(tested_mutex, false), "Error for " << mutex_name << ": try_acquire on read failed though it should not");
132     CHECK_MESSAGE(lock.try_acquire(tested_mutex, true), "Error for " << mutex_name << ": try_acquire on write failed though it should not");
133 }
134 
//! A mutex paired with an array of N counters that tests always update together.
/** As long as the array is only modified under exclusive access, all N elements
    hold the same value; value_is() checks exactly that. */
template<typename M, size_t N>
struct ArrayCounter {
    using mutex_type = M;
    M mutex;
    long value[N];

    //! Starts every element at zero.
    ArrayCounter() : value{} {}

    //! Adds one to every element.
    void increment() {
        for (long& element : value) {
            element += 1;
        }
    }

    //! Returns true iff every element equals expected_value.
    bool value_is(long expected_value) const {
        for (const long element : value) {
            if (element != expected_value) return false;
        }
        return true;
    }
};
158 
159 template<typename M, typename Counter>
160 void TestReaderWriterLock_Impl(Counter& counter, typename M::scoped_lock& lock, const std::size_t i, const bool write) {
161     bool okay = true;
162     if (write) {
163         long counter_value = counter.value[0];
164         counter.increment();
165         // Downgrade to reader
166         if (i % 16 == 7) {
167             if (!lock.downgrade_to_reader()) {
168                 // Get the previous value as downgrade with the same lock acquired was failed
169                 counter_value = counter.value[0] - 1;
170             }
171             okay = counter.value_is(counter_value + 1);
172         }
173     } else {
174         okay = counter.value_is(counter.value[0]);
175         // Upgrade to writer
176         if (i % 8 == 3) {
177             long counter_value = counter.value[0];
178             if (!lock.upgrade_to_writer()) {
179                 // Failed to upgrade, reacquiring happened, need to update the value
180                 counter_value = counter.value[0];
181             }
182             counter.increment();
183             okay = counter.value_is(counter_value + 1);
184         }
185     }
186     CHECK_MESSAGE(okay, "Error in read write mutex operations");
187 }
188 
189 //! Shared mutex type test
190 template<typename M>
191 void TestReaderWriterLock(const char* mutex_name) {
192     ArrayCounter<M, 8> counter;
193     const int N = 10000;
194 #if TBB_TEST_LOW_WORKLOAD
195     const int GRAIN = 500;
196 #else
197     const int GRAIN = 100;
198 #endif /* TBB_TEST_LOW_WORKLOAD */
199 
200     // Stress test similar to the general, but with upgrade/downgrade cases
201     utils::NativeParallelFor(N, GRAIN, [&](int i) {
202         //! Every 8th access is a write access
203         const bool write = (i % 8) == 7;
204         if (i & 1) {
205             // Try implicit acquire and explicit release
206             typename M::scoped_lock lock(counter.mutex, write);
207             TestReaderWriterLock_Impl<M, ArrayCounter<M, 8>>(counter, lock, i, write);
208             lock.release();
209         } else {
210             // Try explicit acquire and implicit release
211             typename M::scoped_lock lock;
212             lock.acquire(counter.mutex, write);
213             TestReaderWriterLock_Impl<M, ArrayCounter<M, 8>>(counter, lock, i, write);
214         }
215     });
216     // There is either a writer or a reader upgraded to a writer for each 4th iteration
217     REQUIRE_MESSAGE(counter.value_is(N / 4), "ERROR for " << mutex_name << ": race is detected");
218 }
219 
220 template<typename M>
221 void TestRWStateMultipleChange(const char* mutex_name) {
222     static_assert(M::is_rw_mutex, "Incorrect mutex type");
223 
224     const int N = 1000;
225     const int GRAIN = 100;
226     M mutex;
227     utils::NativeParallelFor(N, GRAIN, [&] (int) {
228         typename M::scoped_lock l(mutex, /*write=*/false);
229         for (int i = 0; i != GRAIN; ++i) {
230             CHECK_MESSAGE(l.downgrade_to_reader(), mutex_name << " downgrade must succeed for read lock");
231         }
232         l.upgrade_to_writer();
233         for (int i = 0; i != GRAIN; ++i) {
234             CHECK_MESSAGE(l.upgrade_to_writer(), mutex_name << " upgrade must succeed for write lock");
235         }
236     });
237 }
238 
//! Adaptor for using ISO C++0x style mutex as a TBB-style mutex.
/** M must provide lock(), unlock() and try_lock(). When M::is_rw_mutex is true it must
    also provide lock_shared(), unlock_shared() and try_lock_shared(). M must expose the
    compile-time flags is_recursive_mutex and is_rw_mutex, which are re-exported here. */
template<typename M>
class TBB_MutexFromISO_Mutex {
    M my_iso_mutex;
public:
    using mutex_type = TBB_MutexFromISO_Mutex;

    //! TBB-style lock object; releases the held mutex (if any) on destruction.
    /** Non-copyable: a copy would release the same mutex a second time. */
    class scoped_lock {
        mutex_type* my_mutex;   //!< Held mutex, or nullptr when nothing is held
        bool m_is_writer;       //!< Whether the mutex was (or is being) taken exclusively
    public:
        scoped_lock() : my_mutex(nullptr), m_is_writer(false) {}
        scoped_lock(mutex_type& m) : my_mutex(nullptr), m_is_writer(false) {
            acquire(m);
        }
        scoped_lock(mutex_type& m, bool is_writer) : my_mutex(nullptr), m_is_writer(false) {
            acquire(m, is_writer);
        }

        scoped_lock(const scoped_lock&) = delete;
        scoped_lock& operator=(const scoped_lock&) = delete;

        //! Acquires m exclusively, blocking until it is available.
        void acquire(mutex_type& m) {
            m_is_writer = true;
            m.my_iso_mutex.lock();
            my_mutex = &m;
        }
        //! Tries to acquire m exclusively without blocking; returns true on success.
        bool try_acquire(mutex_type& m) {
            m_is_writer = true;
            if (!m.my_iso_mutex.try_lock()) {
                return false;
            }
            my_mutex = &m;
            return true;
        }

        //! Releases the held mutex (selected when M is not a reader-writer mutex).
        template<typename Q = M>
        typename std::enable_if<!Q::is_rw_mutex>::type release() {
            my_mutex->my_iso_mutex.unlock();
            my_mutex = nullptr;
        }

        //! Releases the held mutex in the mode it was taken (reader-writer M).
        template<typename Q = M>
        typename std::enable_if<Q::is_rw_mutex>::type release() {
            if (m_is_writer)
                my_mutex->my_iso_mutex.unlock();
            else
                my_mutex->my_iso_mutex.unlock_shared();
            my_mutex = nullptr;
        }

        // Methods for reader-writer mutex
        // These methods can be instantiated only if M supports lock_shared() and try_lock_shared().

        //! Acquires m exclusively (is_writer) or shared, blocking until available.
        void acquire(mutex_type& m, bool is_writer) {
            m_is_writer = is_writer;
            if (is_writer) m.my_iso_mutex.lock();
            else m.my_iso_mutex.lock_shared();
            my_mutex = &m;
        }
        //! Tries to acquire m exclusively (is_writer) or shared; returns true on success.
        bool try_acquire(mutex_type& m, bool is_writer) {
            m_is_writer = is_writer;
            if (is_writer ? m.my_iso_mutex.try_lock() : m.my_iso_mutex.try_lock_shared()) {
                my_mutex = &m;
                return true;
            } else {
                return false;
            }
        }
        //! Returns true if already a writer; otherwise drops the shared lock, takes the
        //! exclusive one and returns false (other threads may have run in between).
        bool upgrade_to_writer() {
            if (m_is_writer)
                return true;
            m_is_writer = true;
            my_mutex->my_iso_mutex.unlock_shared();
            my_mutex->my_iso_mutex.lock();
            return false;
        }
        //! Returns true if already a reader; otherwise drops the exclusive lock, takes a
        //! shared one and returns false (other threads may have run in between).
        bool downgrade_to_reader() {
            if (!m_is_writer)
                return true;
            m_is_writer = false;
            my_mutex->my_iso_mutex.unlock();
            my_mutex->my_iso_mutex.lock_shared();
            return false;
        }
        ~scoped_lock() {
            if (my_mutex)
                release();
        }
    };

    static constexpr bool is_recursive_mutex = M::is_recursive_mutex;
    static constexpr bool is_rw_mutex = M::is_rw_mutex;
};
333 
334 template<typename C>
335 struct NullRecursive: utils::NoAssign {
336     void recurse_till(std::size_t i, std::size_t till) const {
337         if(i == till) {
338             counter.value = counter.value + 1;
339             return;
340         }
341         if(i & 1) {
342             typename C::mutex_type::scoped_lock lock2(counter.mutex);
343             recurse_till(i + 1, till);
344             lock2.release();
345         } else {
346             typename C::mutex_type::scoped_lock lock2;
347             lock2.acquire(counter.mutex);
348             recurse_till(i + 1, till);
349         }
350     }
351 
352     void operator()(oneapi::tbb::blocked_range<std::size_t>& range) const {
353         typename C::mutex_type::scoped_lock lock(counter.mutex);
354         recurse_till(range.begin(), range.end());
355     }
356     NullRecursive(C& counter_) : counter(counter_) {
357         REQUIRE_MESSAGE(is_recursive_mutex, "Null mutex should be a recursive mutex.");
358     }
359     C& counter;
360     bool is_recursive_mutex = C::mutex_type::is_recursive_mutex;
361 };
362 
363 template<typename M>
364 struct NullUpgradeDowngrade: utils::NoAssign {
365     void operator()(oneapi::tbb::blocked_range<std::size_t>& range) const {
366         typename M::scoped_lock lock2;
367         for(std::size_t i = range.begin(); i != range.end(); ++i) {
368             if(i & 1) {
369                 typename M::scoped_lock lock1(my_mutex, true);
370                 if(lock1.downgrade_to_reader() == false) {
371                     REQUIRE_MESSAGE(false, "ERROR for " << mutex_name << ": downgrade should always succeed");
372                 }
373             } else {
374                 lock2.acquire(my_mutex, false);
375                 if(lock2.upgrade_to_writer() == false) {
376                     REQUIRE_MESSAGE(false, "ERROR for " << mutex_name << ": upgrade should always succeed");
377                 }
378                 lock2.release();
379             }
380         }
381     }
382 
383     NullUpgradeDowngrade(M& m_, const char* n_) : my_mutex(m_), mutex_name(n_) {}
384     M& my_mutex;
385     const char* mutex_name;
386 };
387 
388 template<typename M>
389 void TestNullMutex(const char* mutex_name) {
390     INFO(mutex_name);
391     Counter<M> counter;
392     counter.value = 0;
393     const std::size_t n = 100;
394     oneapi::tbb::parallel_for(oneapi::tbb::blocked_range<std::size_t>(0, n, 10), NullRecursive<Counter<M>>(counter));
395     M m;
396     m.lock();
397     REQUIRE(m.try_lock());
398     m.unlock();
399 }
400 
401 template<typename M>
402 void TestNullRWMutex(const char* mutex_name) {
403     const std::size_t n = 100;
404     M m;
405     oneapi::tbb::parallel_for(oneapi::tbb::blocked_range<std::size_t>(0, n, 10), NullUpgradeDowngrade<M>(m, mutex_name));
406     m.lock();
407     REQUIRE(m.try_lock());
408     m.lock_shared();
409     REQUIRE(m.try_lock_shared());
410     m.unlock_shared();
411     m.unlock();
412 }
413 
//! Testing Mutex requirements
//! \brief \ref interface \ref requirement
TEST_CASE("Basic Locable requirement test") {
    // BasicLockable
    // GeneralTest stresses each mutex through scoped_lock and requires no lost increments
    GeneralTest<oneapi::tbb::spin_mutex>("Spin Mutex");
    GeneralTest<oneapi::tbb::spin_rw_mutex>("Spin RW Mutex");
    GeneralTest<oneapi::tbb::queuing_mutex>("Queuing Mutex");
    GeneralTest<oneapi::tbb::queuing_rw_mutex>("Queuing RW Mutex");
    GeneralTest<oneapi::tbb::speculative_spin_mutex>("Speculative Spin Mutex");
    GeneralTest<oneapi::tbb::speculative_spin_rw_mutex>("Speculative Spin RW Mutex");
    // NullMutexes
    // check = false: null mutexes provide no exclusion, so the race check is skipped
    GeneralTest<oneapi::tbb::null_mutex>("Null Mutex", false);
    GeneralTest<oneapi::tbb::null_rw_mutex>("Null RW Mutex", false);
    // Recursive locking and raw lock()/try_lock()/unlock() on the null mutexes
    TestNullMutex<oneapi::tbb::null_mutex>("Null Mutex");
    TestNullMutex<oneapi::tbb::null_rw_mutex>("Null RW Mutex");
}
430 
//! Testing Lockable (try_acquire) requirements
//! \brief \ref interface \ref requirement
TEST_CASE("Lockable requirement test") {
    // Lockable - single threaded try_acquire operations
    // null_mutex uses a TestTryAcquire specialization in which every try_acquire must succeed
    TestTryAcquire<oneapi::tbb::spin_mutex>("Spin Mutex");
    TestTryAcquire<oneapi::tbb::spin_rw_mutex>("Spin RW Mutex");
    TestTryAcquire<oneapi::tbb::queuing_mutex>("Queuing Mutex");
    TestTryAcquire<oneapi::tbb::queuing_rw_mutex>("Queuing RW Mutex");
    TestTryAcquire<oneapi::tbb::speculative_spin_mutex>("Speculative Spin Mutex");
    TestTryAcquire<oneapi::tbb::speculative_spin_rw_mutex>("Speculative Spin RW Mutex");
    TestTryAcquire<oneapi::tbb::null_mutex>("Null Mutex");
}
442 
//! Testing ReaderWriterMutex requirements
//! \brief \ref interface \ref requirement
TEST_CASE("Shared mutexes (reader/writer) test") {
    // General reader writer capabilities + upgrade/downgrade
    // Multithreaded stress with readers, writers, upgrades and downgrades
    TestReaderWriterLock<oneapi::tbb::spin_rw_mutex>("Spin RW Mutex");
    TestReaderWriterLock<oneapi::tbb::queuing_rw_mutex>("Queuing RW Mutex");
    TestReaderWriterLock<oneapi::tbb::speculative_spin_rw_mutex>("Speculative Spin RW Mutex");
    TestNullRWMutex<oneapi::tbb::null_rw_mutex>("Null RW Mutex");
    // Single threaded read/write try_acquire operations
    // (interleaved with TestRWStateMultipleChange, which checks repeated
    // downgrade/upgrade on a lock already in the requested mode)
    TestTryAcquireReader<oneapi::tbb::spin_rw_mutex>("Spin RW Mutex");
    TestTryAcquireReader<oneapi::tbb::queuing_rw_mutex>("Queuing RW Mutex");
    TestRWStateMultipleChange<oneapi::tbb::spin_rw_mutex>("Spin RW Mutex");
    TestRWStateMultipleChange<oneapi::tbb::queuing_rw_mutex>("Queuing RW Mutex");
    TestTryAcquireReader<oneapi::tbb::speculative_spin_rw_mutex>("Speculative Spin RW Mutex");
    TestRWStateMultipleChange<oneapi::tbb::speculative_spin_rw_mutex>("Speculative Spin RW Mutex");
    TestTryAcquireReader<oneapi::tbb::null_rw_mutex>("Null RW Mutex");
}
460 
//! Testing ISO C++ Mutex and Shared Mutex requirements.
//! Compatibility with the standard
//! \brief \ref interface \ref requirement
TEST_CASE("ISO interface test") {
    // TBB_MutexFromISO_Mutex drives the TBB mutexes only through their ISO-style
    // lock()/try_lock()/unlock() (and *_shared for RW) members, so the generic tests
    // below check that interface
    GeneralTest<TBB_MutexFromISO_Mutex<oneapi::tbb::spin_mutex> >("ISO Spin Mutex");
    GeneralTest<TBB_MutexFromISO_Mutex<oneapi::tbb::spin_rw_mutex> >("ISO Spin RW Mutex");
    TestTryAcquire<TBB_MutexFromISO_Mutex<oneapi::tbb::spin_mutex> >("ISO Spin Mutex");
    TestTryAcquire<TBB_MutexFromISO_Mutex<oneapi::tbb::spin_rw_mutex> >("ISO Spin RW Mutex");
    TestTryAcquireReader<TBB_MutexFromISO_Mutex<oneapi::tbb::spin_rw_mutex> >("ISO Spin RW Mutex");
    TestReaderWriterLock<TBB_MutexFromISO_Mutex<oneapi::tbb::spin_rw_mutex> >("ISO Spin RW Mutex");
}
472 
473