1d86ed7fbStbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2005-2021 Intel Corporation
3d86ed7fbStbbdev 
4d86ed7fbStbbdev     Licensed under the Apache License, Version 2.0 (the "License");
5d86ed7fbStbbdev     you may not use this file except in compliance with the License.
6d86ed7fbStbbdev     You may obtain a copy of the License at
7d86ed7fbStbbdev 
8d86ed7fbStbbdev         http://www.apache.org/licenses/LICENSE-2.0
9d86ed7fbStbbdev 
10d86ed7fbStbbdev     Unless required by applicable law or agreed to in writing, software
11d86ed7fbStbbdev     distributed under the License is distributed on an "AS IS" BASIS,
12d86ed7fbStbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13d86ed7fbStbbdev     See the License for the specific language governing permissions and
14d86ed7fbStbbdev     limitations under the License.
15d86ed7fbStbbdev */
16d86ed7fbStbbdev 
// Example program that computes the number of prime numbers up to and
// including n, where n is a command line argument.  The algorithm here is a
// fairly efficient version of the sieve of Eratosthenes.
// The parallel version demonstrates how to use parallel_reduce,
// and in particular how to exploit lazy splitting.
22d86ed7fbStbbdev 
23d86ed7fbStbbdev #include <cassert>
24d86ed7fbStbbdev #include <cstdio>
25d86ed7fbStbbdev #include <cstring>
26d86ed7fbStbbdev #include <cmath>
27d86ed7fbStbbdev #include <cstdlib>
28d86ed7fbStbbdev #include <cctype>
29d86ed7fbStbbdev 
30d86ed7fbStbbdev #include <algorithm>
31d86ed7fbStbbdev 
32d86ed7fbStbbdev #include "oneapi/tbb/parallel_reduce.h"
33d86ed7fbStbbdev #include "oneapi/tbb/global_control.h"
34d86ed7fbStbbdev 
35d86ed7fbStbbdev #include "primes.hpp"
36d86ed7fbStbbdev 
//! When set, every prime that is found is also written to stdout, one per line.
static bool printPrimes{false};
39d86ed7fbStbbdev 
40d86ed7fbStbbdev class Multiples {
strike(NumberType start,NumberType limit,NumberType stride)41d86ed7fbStbbdev     inline NumberType strike(NumberType start, NumberType limit, NumberType stride) {
42d86ed7fbStbbdev         // Hoist "my_is_composite" into register for sake of speed.
43d86ed7fbStbbdev         bool* is_composite = my_is_composite;
44d86ed7fbStbbdev         assert(stride >= 2);
45d86ed7fbStbbdev         for (; start < limit; start += stride)
46d86ed7fbStbbdev             is_composite[start] = true;
47d86ed7fbStbbdev         return start;
48d86ed7fbStbbdev     }
49d86ed7fbStbbdev     //! Window into conceptual sieve
50d86ed7fbStbbdev     bool* my_is_composite;
51d86ed7fbStbbdev 
52d86ed7fbStbbdev     //! Indexes into window
53d86ed7fbStbbdev     /** my_striker[k] is an index into my_composite corresponding to
54d86ed7fbStbbdev         an odd multiple multiple of my_factor[k]. */
55d86ed7fbStbbdev     NumberType* my_striker;
56d86ed7fbStbbdev 
57d86ed7fbStbbdev     //! Prime numbers less than m.
58d86ed7fbStbbdev     NumberType* my_factor;
59d86ed7fbStbbdev 
60d86ed7fbStbbdev public:
61d86ed7fbStbbdev     //! NumberType of factors in my_factor.
62d86ed7fbStbbdev     NumberType n_factor;
63d86ed7fbStbbdev     NumberType m;
Multiples(NumberType n)64d86ed7fbStbbdev     Multiples(NumberType n) {
65d86ed7fbStbbdev         m = NumberType(sqrt(double(n)));
66d86ed7fbStbbdev         // Round up to even
67d86ed7fbStbbdev         m += m & 1;
68d86ed7fbStbbdev         my_is_composite = new bool[m / 2];
69d86ed7fbStbbdev         my_striker = new NumberType[m / 2];
70d86ed7fbStbbdev         my_factor = new NumberType[m / 2];
71d86ed7fbStbbdev         n_factor = 0;
72d86ed7fbStbbdev         memset(my_is_composite, 0, m / 2);
73d86ed7fbStbbdev         for (NumberType i = 3; i < m; i += 2) {
74d86ed7fbStbbdev             if (!my_is_composite[i / 2]) {
75d86ed7fbStbbdev                 if (printPrimes)
76d86ed7fbStbbdev                     printf("%d\n", (int)i);
77d86ed7fbStbbdev                 my_striker[n_factor] = strike(i / 2, m / 2, i);
78d86ed7fbStbbdev                 my_factor[n_factor++] = i;
79d86ed7fbStbbdev             }
80d86ed7fbStbbdev         }
81d86ed7fbStbbdev     }
82d86ed7fbStbbdev 
83d86ed7fbStbbdev     //! Find primes in range [start,window_size), advancing my_striker as we go.
84d86ed7fbStbbdev     /** Returns number of primes found. */
find_primes_in_window(NumberType start,NumberType window_size)85d86ed7fbStbbdev     NumberType find_primes_in_window(NumberType start, NumberType window_size) {
86d86ed7fbStbbdev         bool* is_composite = my_is_composite;
87d86ed7fbStbbdev         memset(is_composite, 0, window_size / 2);
88d86ed7fbStbbdev         for (std::size_t k = 0; k < n_factor; ++k)
89d86ed7fbStbbdev             my_striker[k] = strike(my_striker[k] - m / 2, window_size / 2, my_factor[k]);
90d86ed7fbStbbdev         NumberType count = 0;
91d86ed7fbStbbdev         for (NumberType k = 0; k < window_size / 2; ++k) {
92d86ed7fbStbbdev             if (!is_composite[k]) {
93d86ed7fbStbbdev                 if (printPrimes)
94d86ed7fbStbbdev                     printf("%ld\n", long(start + 2 * k + 1));
95d86ed7fbStbbdev                 ++count;
96d86ed7fbStbbdev             }
97d86ed7fbStbbdev         }
98d86ed7fbStbbdev         return count;
99d86ed7fbStbbdev     }
100d86ed7fbStbbdev 
~Multiples()101d86ed7fbStbbdev     ~Multiples() {
102d86ed7fbStbbdev         delete[] my_factor;
103d86ed7fbStbbdev         delete[] my_striker;
104d86ed7fbStbbdev         delete[] my_is_composite;
105d86ed7fbStbbdev     }
106d86ed7fbStbbdev 
107d86ed7fbStbbdev     //------------------------------------------------------------------------
108d86ed7fbStbbdev     // Begin extra members required by parallel version
109d86ed7fbStbbdev     //------------------------------------------------------------------------
110d86ed7fbStbbdev 
111d86ed7fbStbbdev     // Splitting constructor
Multiples(const Multiples & f,oneapi::tbb::split)112d86ed7fbStbbdev     Multiples(const Multiples& f, oneapi::tbb::split)
113d86ed7fbStbbdev             : n_factor(f.n_factor),
114d86ed7fbStbbdev               m(f.m),
115d86ed7fbStbbdev               my_is_composite(nullptr),
116d86ed7fbStbbdev               my_striker(nullptr),
117d86ed7fbStbbdev               my_factor(f.my_factor) {}
118d86ed7fbStbbdev 
is_initialized() const119d86ed7fbStbbdev     bool is_initialized() const {
120d86ed7fbStbbdev         return my_is_composite != nullptr;
121d86ed7fbStbbdev     }
122d86ed7fbStbbdev 
initialize(NumberType start)123d86ed7fbStbbdev     void initialize(NumberType start) {
124d86ed7fbStbbdev         assert(start >= 1);
125d86ed7fbStbbdev         my_is_composite = new bool[m / 2];
126d86ed7fbStbbdev         my_striker = new NumberType[m / 2];
127d86ed7fbStbbdev         for (std::size_t k = 0; k < n_factor; ++k) {
128d86ed7fbStbbdev             NumberType f = my_factor[k];
129d86ed7fbStbbdev             NumberType p = (start - 1) / f * f % m;
130d86ed7fbStbbdev             my_striker[k] = (p & 1 ? p + 2 * f : p + f) / 2;
131d86ed7fbStbbdev             assert(m / 2 <= my_striker[k]);
132d86ed7fbStbbdev         }
133d86ed7fbStbbdev     }
134d86ed7fbStbbdev 
135d86ed7fbStbbdev     // Move other to *this.
move(Multiples & other)136d86ed7fbStbbdev     void move(Multiples& other) {
137d86ed7fbStbbdev         // The swap moves the contents of other to *this and causes the old contents
138d86ed7fbStbbdev         // of *this to be deleted later when other is destroyed.
139d86ed7fbStbbdev         std::swap(my_striker, other.my_striker);
140d86ed7fbStbbdev         std::swap(my_is_composite, other.my_is_composite);
141d86ed7fbStbbdev         // other.my_factor is a shared pointer that was copied by the splitting constructor.
142d86ed7fbStbbdev         // Set it to nullptr to prevent premature deletion by the destructor of ~other.
143d86ed7fbStbbdev         assert(my_factor == other.my_factor);
144d86ed7fbStbbdev         other.my_factor = nullptr;
145d86ed7fbStbbdev     }
146d86ed7fbStbbdev 
147d86ed7fbStbbdev     //------------------------------------------------------------------------
148d86ed7fbStbbdev     // End extra methods required by parallel version
149d86ed7fbStbbdev     //------------------------------------------------------------------------
150d86ed7fbStbbdev };
151d86ed7fbStbbdev 
152d86ed7fbStbbdev //! Count number of primes between 0 and n
153d86ed7fbStbbdev /** This is the serial version. */
SerialCountPrimes(NumberType n)154d86ed7fbStbbdev NumberType SerialCountPrimes(NumberType n) {
155d86ed7fbStbbdev     // Two is special case
156d86ed7fbStbbdev     NumberType count = n >= 2;
157d86ed7fbStbbdev     if (n >= 3) {
158d86ed7fbStbbdev         Multiples multiples(n);
159d86ed7fbStbbdev         count += multiples.n_factor;
160d86ed7fbStbbdev         if (printPrimes)
161d86ed7fbStbbdev             printf("---\n");
162d86ed7fbStbbdev         NumberType window_size = multiples.m;
163d86ed7fbStbbdev         for (NumberType j = multiples.m; j <= n; j += window_size) {
164d86ed7fbStbbdev             if (j + window_size > n + 1)
165d86ed7fbStbbdev                 window_size = n + 1 - j;
166d86ed7fbStbbdev             count += multiples.find_primes_in_window(j, window_size);
167d86ed7fbStbbdev         }
168d86ed7fbStbbdev     }
169d86ed7fbStbbdev     return count;
170d86ed7fbStbbdev }
171d86ed7fbStbbdev 
172d86ed7fbStbbdev //! Range of a sieve window.
173d86ed7fbStbbdev class SieveRange {
174d86ed7fbStbbdev     //! Width of full-size window into sieve.
175d86ed7fbStbbdev     const NumberType my_stride;
176d86ed7fbStbbdev 
177d86ed7fbStbbdev     //! Always multiple of my_stride
178d86ed7fbStbbdev     NumberType my_begin;
179d86ed7fbStbbdev 
180d86ed7fbStbbdev     //! One past last number in window.
181d86ed7fbStbbdev     NumberType my_end;
182d86ed7fbStbbdev 
183d86ed7fbStbbdev     //! Width above which it is worth forking.
184d86ed7fbStbbdev     const NumberType my_grainsize;
185d86ed7fbStbbdev 
assert_okay() const186d86ed7fbStbbdev     bool assert_okay() const {
187d86ed7fbStbbdev         assert(my_begin % my_stride == 0);
188d86ed7fbStbbdev         assert(my_begin <= my_end);
189d86ed7fbStbbdev         assert(my_stride <= my_grainsize);
190d86ed7fbStbbdev         return true;
191d86ed7fbStbbdev     }
192d86ed7fbStbbdev 
193d86ed7fbStbbdev public:
194d86ed7fbStbbdev     //------------------------------------------------------------------------
195d86ed7fbStbbdev     // Begin signatures required by parallel_reduce
196d86ed7fbStbbdev     //------------------------------------------------------------------------
is_divisible() const197d86ed7fbStbbdev     bool is_divisible() const {
198d86ed7fbStbbdev         return my_end - my_begin > my_grainsize;
199d86ed7fbStbbdev     }
empty() const200d86ed7fbStbbdev     bool empty() const {
201d86ed7fbStbbdev         return my_end <= my_begin;
202d86ed7fbStbbdev     }
SieveRange(SieveRange & r,oneapi::tbb::split)203d86ed7fbStbbdev     SieveRange(SieveRange& r, oneapi::tbb::split)
204d86ed7fbStbbdev             : my_stride(r.my_stride),
205d86ed7fbStbbdev               my_grainsize(r.my_grainsize),
206d86ed7fbStbbdev               my_end(r.my_end) {
207d86ed7fbStbbdev         assert(r.is_divisible());
208d86ed7fbStbbdev         assert(r.assert_okay());
209d86ed7fbStbbdev         NumberType middle = r.my_begin + (r.my_end - r.my_begin + r.my_stride - 1) / 2;
210d86ed7fbStbbdev         middle = middle / my_stride * my_stride;
211d86ed7fbStbbdev         my_begin = middle;
212d86ed7fbStbbdev         r.my_end = middle;
213d86ed7fbStbbdev         assert(assert_okay());
214d86ed7fbStbbdev         assert(r.assert_okay());
215d86ed7fbStbbdev     }
216d86ed7fbStbbdev     //------------------------------------------------------------------------
217d86ed7fbStbbdev     // End of signatures required by parallel_reduce
218d86ed7fbStbbdev     //------------------------------------------------------------------------
begin() const219d86ed7fbStbbdev     NumberType begin() const {
220d86ed7fbStbbdev         return my_begin;
221d86ed7fbStbbdev     }
end() const222d86ed7fbStbbdev     NumberType end() const {
223d86ed7fbStbbdev         return my_end;
224d86ed7fbStbbdev     }
SieveRange(NumberType begin,NumberType end,NumberType stride,NumberType grainsize)225d86ed7fbStbbdev     SieveRange(NumberType begin, NumberType end, NumberType stride, NumberType grainsize)
226d86ed7fbStbbdev             : my_begin(begin),
227d86ed7fbStbbdev               my_end(end),
228d86ed7fbStbbdev               my_stride(stride),
229d86ed7fbStbbdev               my_grainsize(grainsize < stride ? stride : grainsize) {
230d86ed7fbStbbdev         assert(assert_okay());
231d86ed7fbStbbdev     }
232d86ed7fbStbbdev };
233d86ed7fbStbbdev 
234d86ed7fbStbbdev //! Loop body for parallel_reduce.
235d86ed7fbStbbdev /** parallel_reduce splits the sieve into subsieves.
236d86ed7fbStbbdev     Each subsieve handles a subrange of [0..n]. */
237d86ed7fbStbbdev class Sieve {
238d86ed7fbStbbdev public:
239d86ed7fbStbbdev     //! Prime Multiples to consider, and working storage for this subsieve.
240d86ed7fbStbbdev     ::Multiples multiples;
241d86ed7fbStbbdev 
242d86ed7fbStbbdev     //! NumberType of primes found so far by this subsieve.
243d86ed7fbStbbdev     NumberType count;
244d86ed7fbStbbdev 
245d86ed7fbStbbdev     //! Construct Sieve for counting primes in [0..n].
Sieve(NumberType n)246d86ed7fbStbbdev     Sieve(NumberType n) : multiples(n), count(0) {}
247d86ed7fbStbbdev 
248d86ed7fbStbbdev     //------------------------------------------------------------------------
249d86ed7fbStbbdev     // Begin signatures required by parallel_reduce
250d86ed7fbStbbdev     //------------------------------------------------------------------------
operator ()(const SieveRange & r)251d86ed7fbStbbdev     void operator()(const SieveRange& r) {
252d86ed7fbStbbdev         NumberType m = multiples.m;
253d86ed7fbStbbdev         if (multiples.is_initialized()) {
254d86ed7fbStbbdev             // Simply reuse "Multiples" structure from previous window
255d86ed7fbStbbdev             // This works because parallel_reduce always applies
256d86ed7fbStbbdev             // *this from left to right.
257d86ed7fbStbbdev         }
258d86ed7fbStbbdev         else {
259d86ed7fbStbbdev             // Need to initialize "Multiples" because *this is a forked copy
260d86ed7fbStbbdev             // that needs to be set up to start at r.begin().
261d86ed7fbStbbdev             multiples.initialize(r.begin());
262d86ed7fbStbbdev         }
263d86ed7fbStbbdev         NumberType window_size = m;
264d86ed7fbStbbdev         for (NumberType j = r.begin(); j < r.end(); j += window_size) {
265d86ed7fbStbbdev             assert(j % multiples.m == 0);
266d86ed7fbStbbdev             if (j + window_size > r.end())
267d86ed7fbStbbdev                 window_size = r.end() - j;
268d86ed7fbStbbdev             count += multiples.find_primes_in_window(j, window_size);
269d86ed7fbStbbdev         }
270d86ed7fbStbbdev     }
join(Sieve & other)271d86ed7fbStbbdev     void join(Sieve& other) {
272d86ed7fbStbbdev         count += other.count;
273d86ed7fbStbbdev         // Final value of multiples needs to final value of other multiples,
274d86ed7fbStbbdev         // so that *this can correctly process next window to right.
275d86ed7fbStbbdev         multiples.move(other.multiples);
276d86ed7fbStbbdev     }
Sieve(Sieve & other,oneapi::tbb::split)277d86ed7fbStbbdev     Sieve(Sieve& other, oneapi::tbb::split)
278d86ed7fbStbbdev             : multiples(other.multiples, oneapi::tbb::split()),
279d86ed7fbStbbdev               count(0) {}
280d86ed7fbStbbdev     //------------------------------------------------------------------------
281d86ed7fbStbbdev     // End of signatures required by parallel_reduce
282d86ed7fbStbbdev     //------------------------------------------------------------------------
283d86ed7fbStbbdev };
284d86ed7fbStbbdev 
285d86ed7fbStbbdev //! Count number of primes between 0 and n
286d86ed7fbStbbdev /** This is the parallel version. */
ParallelCountPrimes(NumberType n,int number_of_threads,NumberType grain_size)287d86ed7fbStbbdev NumberType ParallelCountPrimes(NumberType n, int number_of_threads, NumberType grain_size) {
288d86ed7fbStbbdev     oneapi::tbb::global_control c(oneapi::tbb::global_control::max_allowed_parallelism,
289d86ed7fbStbbdev                                   number_of_threads);
290d86ed7fbStbbdev 
291d86ed7fbStbbdev     // Two is special case
292d86ed7fbStbbdev     NumberType count = n >= 2;
293d86ed7fbStbbdev     if (n >= 3) {
294d86ed7fbStbbdev         Sieve s(n);
295d86ed7fbStbbdev         count += s.multiples.n_factor;
296d86ed7fbStbbdev         if (printPrimes)
297d86ed7fbStbbdev             printf("---\n");
298d86ed7fbStbbdev         // Explicit grain size and simple_partitioner() used here instead of automatic grainsize
299d86ed7fbStbbdev         // determination because we want SieveRange to be decomposed down to grainSize or smaller.
300d86ed7fbStbbdev         // Doing so improves odds that the working set fits in cache when evaluating Sieve::operator().
301d86ed7fbStbbdev         oneapi::tbb::parallel_reduce(SieveRange(s.multiples.m, n, s.multiples.m, grain_size),
302d86ed7fbStbbdev                                      s,
303d86ed7fbStbbdev                                      oneapi::tbb::simple_partitioner());
304d86ed7fbStbbdev         count += s.count;
305d86ed7fbStbbdev     }
306d86ed7fbStbbdev     return count;
307d86ed7fbStbbdev }
308