/*
    Copyright (c) 2005-2021 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/
16d86ed7fbStbbdev
// Example program that computes the number of prime numbers up to n,
// where n is a command line argument.  The algorithm here is a
// fairly efficient version of the sieve of Eratosthenes.
// The parallel version demonstrates how to use parallel_reduce,
// and in particular how to exploit lazy splitting.
22d86ed7fbStbbdev
23d86ed7fbStbbdev #include <cassert>
24d86ed7fbStbbdev #include <cstdio>
25d86ed7fbStbbdev #include <cstring>
26d86ed7fbStbbdev #include <cmath>
27d86ed7fbStbbdev #include <cstdlib>
28d86ed7fbStbbdev #include <cctype>
29d86ed7fbStbbdev
30d86ed7fbStbbdev #include <algorithm>
31d86ed7fbStbbdev
32d86ed7fbStbbdev #include "oneapi/tbb/parallel_reduce.h"
33d86ed7fbStbbdev #include "oneapi/tbb/global_control.h"
34d86ed7fbStbbdev
35d86ed7fbStbbdev #include "primes.hpp"
36d86ed7fbStbbdev
//! If true, then print primes on stdout.
/** File-local switch that is read by the sieve code below. */
static bool printPrimes = false;
39d86ed7fbStbbdev
40d86ed7fbStbbdev class Multiples {
strike(NumberType start,NumberType limit,NumberType stride)41d86ed7fbStbbdev inline NumberType strike(NumberType start, NumberType limit, NumberType stride) {
42d86ed7fbStbbdev // Hoist "my_is_composite" into register for sake of speed.
43d86ed7fbStbbdev bool* is_composite = my_is_composite;
44d86ed7fbStbbdev assert(stride >= 2);
45d86ed7fbStbbdev for (; start < limit; start += stride)
46d86ed7fbStbbdev is_composite[start] = true;
47d86ed7fbStbbdev return start;
48d86ed7fbStbbdev }
49d86ed7fbStbbdev //! Window into conceptual sieve
50d86ed7fbStbbdev bool* my_is_composite;
51d86ed7fbStbbdev
52d86ed7fbStbbdev //! Indexes into window
53d86ed7fbStbbdev /** my_striker[k] is an index into my_composite corresponding to
54d86ed7fbStbbdev an odd multiple multiple of my_factor[k]. */
55d86ed7fbStbbdev NumberType* my_striker;
56d86ed7fbStbbdev
57d86ed7fbStbbdev //! Prime numbers less than m.
58d86ed7fbStbbdev NumberType* my_factor;
59d86ed7fbStbbdev
60d86ed7fbStbbdev public:
61d86ed7fbStbbdev //! NumberType of factors in my_factor.
62d86ed7fbStbbdev NumberType n_factor;
63d86ed7fbStbbdev NumberType m;
Multiples(NumberType n)64d86ed7fbStbbdev Multiples(NumberType n) {
65d86ed7fbStbbdev m = NumberType(sqrt(double(n)));
66d86ed7fbStbbdev // Round up to even
67d86ed7fbStbbdev m += m & 1;
68d86ed7fbStbbdev my_is_composite = new bool[m / 2];
69d86ed7fbStbbdev my_striker = new NumberType[m / 2];
70d86ed7fbStbbdev my_factor = new NumberType[m / 2];
71d86ed7fbStbbdev n_factor = 0;
72d86ed7fbStbbdev memset(my_is_composite, 0, m / 2);
73d86ed7fbStbbdev for (NumberType i = 3; i < m; i += 2) {
74d86ed7fbStbbdev if (!my_is_composite[i / 2]) {
75d86ed7fbStbbdev if (printPrimes)
76d86ed7fbStbbdev printf("%d\n", (int)i);
77d86ed7fbStbbdev my_striker[n_factor] = strike(i / 2, m / 2, i);
78d86ed7fbStbbdev my_factor[n_factor++] = i;
79d86ed7fbStbbdev }
80d86ed7fbStbbdev }
81d86ed7fbStbbdev }
82d86ed7fbStbbdev
83d86ed7fbStbbdev //! Find primes in range [start,window_size), advancing my_striker as we go.
84d86ed7fbStbbdev /** Returns number of primes found. */
find_primes_in_window(NumberType start,NumberType window_size)85d86ed7fbStbbdev NumberType find_primes_in_window(NumberType start, NumberType window_size) {
86d86ed7fbStbbdev bool* is_composite = my_is_composite;
87d86ed7fbStbbdev memset(is_composite, 0, window_size / 2);
88d86ed7fbStbbdev for (std::size_t k = 0; k < n_factor; ++k)
89d86ed7fbStbbdev my_striker[k] = strike(my_striker[k] - m / 2, window_size / 2, my_factor[k]);
90d86ed7fbStbbdev NumberType count = 0;
91d86ed7fbStbbdev for (NumberType k = 0; k < window_size / 2; ++k) {
92d86ed7fbStbbdev if (!is_composite[k]) {
93d86ed7fbStbbdev if (printPrimes)
94d86ed7fbStbbdev printf("%ld\n", long(start + 2 * k + 1));
95d86ed7fbStbbdev ++count;
96d86ed7fbStbbdev }
97d86ed7fbStbbdev }
98d86ed7fbStbbdev return count;
99d86ed7fbStbbdev }
100d86ed7fbStbbdev
~Multiples()101d86ed7fbStbbdev ~Multiples() {
102d86ed7fbStbbdev delete[] my_factor;
103d86ed7fbStbbdev delete[] my_striker;
104d86ed7fbStbbdev delete[] my_is_composite;
105d86ed7fbStbbdev }
106d86ed7fbStbbdev
107d86ed7fbStbbdev //------------------------------------------------------------------------
108d86ed7fbStbbdev // Begin extra members required by parallel version
109d86ed7fbStbbdev //------------------------------------------------------------------------
110d86ed7fbStbbdev
111d86ed7fbStbbdev // Splitting constructor
Multiples(const Multiples & f,oneapi::tbb::split)112d86ed7fbStbbdev Multiples(const Multiples& f, oneapi::tbb::split)
113d86ed7fbStbbdev : n_factor(f.n_factor),
114d86ed7fbStbbdev m(f.m),
115d86ed7fbStbbdev my_is_composite(nullptr),
116d86ed7fbStbbdev my_striker(nullptr),
117d86ed7fbStbbdev my_factor(f.my_factor) {}
118d86ed7fbStbbdev
is_initialized() const119d86ed7fbStbbdev bool is_initialized() const {
120d86ed7fbStbbdev return my_is_composite != nullptr;
121d86ed7fbStbbdev }
122d86ed7fbStbbdev
initialize(NumberType start)123d86ed7fbStbbdev void initialize(NumberType start) {
124d86ed7fbStbbdev assert(start >= 1);
125d86ed7fbStbbdev my_is_composite = new bool[m / 2];
126d86ed7fbStbbdev my_striker = new NumberType[m / 2];
127d86ed7fbStbbdev for (std::size_t k = 0; k < n_factor; ++k) {
128d86ed7fbStbbdev NumberType f = my_factor[k];
129d86ed7fbStbbdev NumberType p = (start - 1) / f * f % m;
130d86ed7fbStbbdev my_striker[k] = (p & 1 ? p + 2 * f : p + f) / 2;
131d86ed7fbStbbdev assert(m / 2 <= my_striker[k]);
132d86ed7fbStbbdev }
133d86ed7fbStbbdev }
134d86ed7fbStbbdev
135d86ed7fbStbbdev // Move other to *this.
move(Multiples & other)136d86ed7fbStbbdev void move(Multiples& other) {
137d86ed7fbStbbdev // The swap moves the contents of other to *this and causes the old contents
138d86ed7fbStbbdev // of *this to be deleted later when other is destroyed.
139d86ed7fbStbbdev std::swap(my_striker, other.my_striker);
140d86ed7fbStbbdev std::swap(my_is_composite, other.my_is_composite);
141d86ed7fbStbbdev // other.my_factor is a shared pointer that was copied by the splitting constructor.
142d86ed7fbStbbdev // Set it to nullptr to prevent premature deletion by the destructor of ~other.
143d86ed7fbStbbdev assert(my_factor == other.my_factor);
144d86ed7fbStbbdev other.my_factor = nullptr;
145d86ed7fbStbbdev }
146d86ed7fbStbbdev
147d86ed7fbStbbdev //------------------------------------------------------------------------
148d86ed7fbStbbdev // End extra methods required by parallel version
149d86ed7fbStbbdev //------------------------------------------------------------------------
150d86ed7fbStbbdev };
151d86ed7fbStbbdev
152d86ed7fbStbbdev //! Count number of primes between 0 and n
153d86ed7fbStbbdev /** This is the serial version. */
SerialCountPrimes(NumberType n)154d86ed7fbStbbdev NumberType SerialCountPrimes(NumberType n) {
155d86ed7fbStbbdev // Two is special case
156d86ed7fbStbbdev NumberType count = n >= 2;
157d86ed7fbStbbdev if (n >= 3) {
158d86ed7fbStbbdev Multiples multiples(n);
159d86ed7fbStbbdev count += multiples.n_factor;
160d86ed7fbStbbdev if (printPrimes)
161d86ed7fbStbbdev printf("---\n");
162d86ed7fbStbbdev NumberType window_size = multiples.m;
163d86ed7fbStbbdev for (NumberType j = multiples.m; j <= n; j += window_size) {
164d86ed7fbStbbdev if (j + window_size > n + 1)
165d86ed7fbStbbdev window_size = n + 1 - j;
166d86ed7fbStbbdev count += multiples.find_primes_in_window(j, window_size);
167d86ed7fbStbbdev }
168d86ed7fbStbbdev }
169d86ed7fbStbbdev return count;
170d86ed7fbStbbdev }
171d86ed7fbStbbdev
172d86ed7fbStbbdev //! Range of a sieve window.
173d86ed7fbStbbdev class SieveRange {
174d86ed7fbStbbdev //! Width of full-size window into sieve.
175d86ed7fbStbbdev const NumberType my_stride;
176d86ed7fbStbbdev
177d86ed7fbStbbdev //! Always multiple of my_stride
178d86ed7fbStbbdev NumberType my_begin;
179d86ed7fbStbbdev
180d86ed7fbStbbdev //! One past last number in window.
181d86ed7fbStbbdev NumberType my_end;
182d86ed7fbStbbdev
183d86ed7fbStbbdev //! Width above which it is worth forking.
184d86ed7fbStbbdev const NumberType my_grainsize;
185d86ed7fbStbbdev
assert_okay() const186d86ed7fbStbbdev bool assert_okay() const {
187d86ed7fbStbbdev assert(my_begin % my_stride == 0);
188d86ed7fbStbbdev assert(my_begin <= my_end);
189d86ed7fbStbbdev assert(my_stride <= my_grainsize);
190d86ed7fbStbbdev return true;
191d86ed7fbStbbdev }
192d86ed7fbStbbdev
193d86ed7fbStbbdev public:
194d86ed7fbStbbdev //------------------------------------------------------------------------
195d86ed7fbStbbdev // Begin signatures required by parallel_reduce
196d86ed7fbStbbdev //------------------------------------------------------------------------
is_divisible() const197d86ed7fbStbbdev bool is_divisible() const {
198d86ed7fbStbbdev return my_end - my_begin > my_grainsize;
199d86ed7fbStbbdev }
empty() const200d86ed7fbStbbdev bool empty() const {
201d86ed7fbStbbdev return my_end <= my_begin;
202d86ed7fbStbbdev }
SieveRange(SieveRange & r,oneapi::tbb::split)203d86ed7fbStbbdev SieveRange(SieveRange& r, oneapi::tbb::split)
204d86ed7fbStbbdev : my_stride(r.my_stride),
205d86ed7fbStbbdev my_grainsize(r.my_grainsize),
206d86ed7fbStbbdev my_end(r.my_end) {
207d86ed7fbStbbdev assert(r.is_divisible());
208d86ed7fbStbbdev assert(r.assert_okay());
209d86ed7fbStbbdev NumberType middle = r.my_begin + (r.my_end - r.my_begin + r.my_stride - 1) / 2;
210d86ed7fbStbbdev middle = middle / my_stride * my_stride;
211d86ed7fbStbbdev my_begin = middle;
212d86ed7fbStbbdev r.my_end = middle;
213d86ed7fbStbbdev assert(assert_okay());
214d86ed7fbStbbdev assert(r.assert_okay());
215d86ed7fbStbbdev }
216d86ed7fbStbbdev //------------------------------------------------------------------------
217d86ed7fbStbbdev // End of signatures required by parallel_reduce
218d86ed7fbStbbdev //------------------------------------------------------------------------
begin() const219d86ed7fbStbbdev NumberType begin() const {
220d86ed7fbStbbdev return my_begin;
221d86ed7fbStbbdev }
end() const222d86ed7fbStbbdev NumberType end() const {
223d86ed7fbStbbdev return my_end;
224d86ed7fbStbbdev }
SieveRange(NumberType begin,NumberType end,NumberType stride,NumberType grainsize)225d86ed7fbStbbdev SieveRange(NumberType begin, NumberType end, NumberType stride, NumberType grainsize)
226d86ed7fbStbbdev : my_begin(begin),
227d86ed7fbStbbdev my_end(end),
228d86ed7fbStbbdev my_stride(stride),
229d86ed7fbStbbdev my_grainsize(grainsize < stride ? stride : grainsize) {
230d86ed7fbStbbdev assert(assert_okay());
231d86ed7fbStbbdev }
232d86ed7fbStbbdev };
233d86ed7fbStbbdev
234d86ed7fbStbbdev //! Loop body for parallel_reduce.
235d86ed7fbStbbdev /** parallel_reduce splits the sieve into subsieves.
236d86ed7fbStbbdev Each subsieve handles a subrange of [0..n]. */
237d86ed7fbStbbdev class Sieve {
238d86ed7fbStbbdev public:
239d86ed7fbStbbdev //! Prime Multiples to consider, and working storage for this subsieve.
240d86ed7fbStbbdev ::Multiples multiples;
241d86ed7fbStbbdev
242d86ed7fbStbbdev //! NumberType of primes found so far by this subsieve.
243d86ed7fbStbbdev NumberType count;
244d86ed7fbStbbdev
245d86ed7fbStbbdev //! Construct Sieve for counting primes in [0..n].
Sieve(NumberType n)246d86ed7fbStbbdev Sieve(NumberType n) : multiples(n), count(0) {}
247d86ed7fbStbbdev
248d86ed7fbStbbdev //------------------------------------------------------------------------
249d86ed7fbStbbdev // Begin signatures required by parallel_reduce
250d86ed7fbStbbdev //------------------------------------------------------------------------
operator ()(const SieveRange & r)251d86ed7fbStbbdev void operator()(const SieveRange& r) {
252d86ed7fbStbbdev NumberType m = multiples.m;
253d86ed7fbStbbdev if (multiples.is_initialized()) {
254d86ed7fbStbbdev // Simply reuse "Multiples" structure from previous window
255d86ed7fbStbbdev // This works because parallel_reduce always applies
256d86ed7fbStbbdev // *this from left to right.
257d86ed7fbStbbdev }
258d86ed7fbStbbdev else {
259d86ed7fbStbbdev // Need to initialize "Multiples" because *this is a forked copy
260d86ed7fbStbbdev // that needs to be set up to start at r.begin().
261d86ed7fbStbbdev multiples.initialize(r.begin());
262d86ed7fbStbbdev }
263d86ed7fbStbbdev NumberType window_size = m;
264d86ed7fbStbbdev for (NumberType j = r.begin(); j < r.end(); j += window_size) {
265d86ed7fbStbbdev assert(j % multiples.m == 0);
266d86ed7fbStbbdev if (j + window_size > r.end())
267d86ed7fbStbbdev window_size = r.end() - j;
268d86ed7fbStbbdev count += multiples.find_primes_in_window(j, window_size);
269d86ed7fbStbbdev }
270d86ed7fbStbbdev }
join(Sieve & other)271d86ed7fbStbbdev void join(Sieve& other) {
272d86ed7fbStbbdev count += other.count;
273d86ed7fbStbbdev // Final value of multiples needs to final value of other multiples,
274d86ed7fbStbbdev // so that *this can correctly process next window to right.
275d86ed7fbStbbdev multiples.move(other.multiples);
276d86ed7fbStbbdev }
Sieve(Sieve & other,oneapi::tbb::split)277d86ed7fbStbbdev Sieve(Sieve& other, oneapi::tbb::split)
278d86ed7fbStbbdev : multiples(other.multiples, oneapi::tbb::split()),
279d86ed7fbStbbdev count(0) {}
280d86ed7fbStbbdev //------------------------------------------------------------------------
281d86ed7fbStbbdev // End of signatures required by parallel_reduce
282d86ed7fbStbbdev //------------------------------------------------------------------------
283d86ed7fbStbbdev };
284d86ed7fbStbbdev
285d86ed7fbStbbdev //! Count number of primes between 0 and n
286d86ed7fbStbbdev /** This is the parallel version. */
ParallelCountPrimes(NumberType n,int number_of_threads,NumberType grain_size)287d86ed7fbStbbdev NumberType ParallelCountPrimes(NumberType n, int number_of_threads, NumberType grain_size) {
288d86ed7fbStbbdev oneapi::tbb::global_control c(oneapi::tbb::global_control::max_allowed_parallelism,
289d86ed7fbStbbdev number_of_threads);
290d86ed7fbStbbdev
291d86ed7fbStbbdev // Two is special case
292d86ed7fbStbbdev NumberType count = n >= 2;
293d86ed7fbStbbdev if (n >= 3) {
294d86ed7fbStbbdev Sieve s(n);
295d86ed7fbStbbdev count += s.multiples.n_factor;
296d86ed7fbStbbdev if (printPrimes)
297d86ed7fbStbbdev printf("---\n");
298d86ed7fbStbbdev // Explicit grain size and simple_partitioner() used here instead of automatic grainsize
299d86ed7fbStbbdev // determination because we want SieveRange to be decomposed down to grainSize or smaller.
300d86ed7fbStbbdev // Doing so improves odds that the working set fits in cache when evaluating Sieve::operator().
301d86ed7fbStbbdev oneapi::tbb::parallel_reduce(SieveRange(s.multiples.m, n, s.multiples.m, grain_size),
302d86ed7fbStbbdev s,
303d86ed7fbStbbdev oneapi::tbb::simple_partitioner());
304d86ed7fbStbbdev count += s.count;
305d86ed7fbStbbdev }
306d86ed7fbStbbdev return count;
307d86ed7fbStbbdev }
308