1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 // Example program that computes number of prime numbers up to n,
18 // where n is a command line argument.  The algorithm here is a
19 // fairly efficient version of the sieve of Eratosthenes.
20 // The parallel version demonstrates how to use parallel_reduce,
21 // and in particular how to exploit lazy splitting.
22 
23 #include <cassert>
24 #include <cstdio>
25 #include <cstring>
26 #include <cmath>
27 #include <cstdlib>
28 #include <cctype>
29 
30 #include <algorithm>
31 
32 #include "oneapi/tbb/parallel_reduce.h"
33 #include "oneapi/tbb/global_control.h"
34 
35 #include "primes.hpp"
36 
//! If true, then print primes on stdout.
/** Read by Multiples while sieving, and by SerialCountPrimes and ParallelCountPrimes, which
    print a "---" line after the initial factors have been listed.
    File-local (static) and off by default. */
static bool printPrimes = false;
39 
40 class Multiples {
41     inline NumberType strike(NumberType start, NumberType limit, NumberType stride) {
42         // Hoist "my_is_composite" into register for sake of speed.
43         bool* is_composite = my_is_composite;
44         assert(stride >= 2);
45         for (; start < limit; start += stride)
46             is_composite[start] = true;
47         return start;
48     }
49     //! Window into conceptual sieve
50     bool* my_is_composite;
51 
52     //! Indexes into window
53     /** my_striker[k] is an index into my_composite corresponding to
54         an odd multiple multiple of my_factor[k]. */
55     NumberType* my_striker;
56 
57     //! Prime numbers less than m.
58     NumberType* my_factor;
59 
60 public:
61     //! NumberType of factors in my_factor.
62     NumberType n_factor;
63     NumberType m;
64     Multiples(NumberType n) {
65         m = NumberType(sqrt(double(n)));
66         // Round up to even
67         m += m & 1;
68         my_is_composite = new bool[m / 2];
69         my_striker = new NumberType[m / 2];
70         my_factor = new NumberType[m / 2];
71         n_factor = 0;
72         memset(my_is_composite, 0, m / 2);
73         for (NumberType i = 3; i < m; i += 2) {
74             if (!my_is_composite[i / 2]) {
75                 if (printPrimes)
76                     printf("%d\n", (int)i);
77                 my_striker[n_factor] = strike(i / 2, m / 2, i);
78                 my_factor[n_factor++] = i;
79             }
80         }
81     }
82 
83     //! Find primes in range [start,window_size), advancing my_striker as we go.
84     /** Returns number of primes found. */
85     NumberType find_primes_in_window(NumberType start, NumberType window_size) {
86         bool* is_composite = my_is_composite;
87         memset(is_composite, 0, window_size / 2);
88         for (std::size_t k = 0; k < n_factor; ++k)
89             my_striker[k] = strike(my_striker[k] - m / 2, window_size / 2, my_factor[k]);
90         NumberType count = 0;
91         for (NumberType k = 0; k < window_size / 2; ++k) {
92             if (!is_composite[k]) {
93                 if (printPrimes)
94                     printf("%ld\n", long(start + 2 * k + 1));
95                 ++count;
96             }
97         }
98         return count;
99     }
100 
101     ~Multiples() {
102         delete[] my_factor;
103         delete[] my_striker;
104         delete[] my_is_composite;
105     }
106 
107     //------------------------------------------------------------------------
108     // Begin extra members required by parallel version
109     //------------------------------------------------------------------------
110 
111     // Splitting constructor
112     Multiples(const Multiples& f, oneapi::tbb::split)
113             : n_factor(f.n_factor),
114               m(f.m),
115               my_is_composite(nullptr),
116               my_striker(nullptr),
117               my_factor(f.my_factor) {}
118 
119     bool is_initialized() const {
120         return my_is_composite != nullptr;
121     }
122 
123     void initialize(NumberType start) {
124         assert(start >= 1);
125         my_is_composite = new bool[m / 2];
126         my_striker = new NumberType[m / 2];
127         for (std::size_t k = 0; k < n_factor; ++k) {
128             NumberType f = my_factor[k];
129             NumberType p = (start - 1) / f * f % m;
130             my_striker[k] = (p & 1 ? p + 2 * f : p + f) / 2;
131             assert(m / 2 <= my_striker[k]);
132         }
133     }
134 
135     // Move other to *this.
136     void move(Multiples& other) {
137         // The swap moves the contents of other to *this and causes the old contents
138         // of *this to be deleted later when other is destroyed.
139         std::swap(my_striker, other.my_striker);
140         std::swap(my_is_composite, other.my_is_composite);
141         // other.my_factor is a shared pointer that was copied by the splitting constructor.
142         // Set it to nullptr to prevent premature deletion by the destructor of ~other.
143         assert(my_factor == other.my_factor);
144         other.my_factor = nullptr;
145     }
146 
147     //------------------------------------------------------------------------
148     // End extra methods required by parallel version
149     //------------------------------------------------------------------------
150 };
151 
152 //! Count number of primes between 0 and n
153 /** This is the serial version. */
154 NumberType SerialCountPrimes(NumberType n) {
155     // Two is special case
156     NumberType count = n >= 2;
157     if (n >= 3) {
158         Multiples multiples(n);
159         count += multiples.n_factor;
160         if (printPrimes)
161             printf("---\n");
162         NumberType window_size = multiples.m;
163         for (NumberType j = multiples.m; j <= n; j += window_size) {
164             if (j + window_size > n + 1)
165                 window_size = n + 1 - j;
166             count += multiples.find_primes_in_window(j, window_size);
167         }
168     }
169     return count;
170 }
171 
172 //! Range of a sieve window.
173 class SieveRange {
174     //! Width of full-size window into sieve.
175     const NumberType my_stride;
176 
177     //! Always multiple of my_stride
178     NumberType my_begin;
179 
180     //! One past last number in window.
181     NumberType my_end;
182 
183     //! Width above which it is worth forking.
184     const NumberType my_grainsize;
185 
186     bool assert_okay() const {
187         assert(my_begin % my_stride == 0);
188         assert(my_begin <= my_end);
189         assert(my_stride <= my_grainsize);
190         return true;
191     }
192 
193 public:
194     //------------------------------------------------------------------------
195     // Begin signatures required by parallel_reduce
196     //------------------------------------------------------------------------
197     bool is_divisible() const {
198         return my_end - my_begin > my_grainsize;
199     }
200     bool empty() const {
201         return my_end <= my_begin;
202     }
203     SieveRange(SieveRange& r, oneapi::tbb::split)
204             : my_stride(r.my_stride),
205               my_grainsize(r.my_grainsize),
206               my_end(r.my_end) {
207         assert(r.is_divisible());
208         assert(r.assert_okay());
209         NumberType middle = r.my_begin + (r.my_end - r.my_begin + r.my_stride - 1) / 2;
210         middle = middle / my_stride * my_stride;
211         my_begin = middle;
212         r.my_end = middle;
213         assert(assert_okay());
214         assert(r.assert_okay());
215     }
216     //------------------------------------------------------------------------
217     // End of signatures required by parallel_reduce
218     //------------------------------------------------------------------------
219     NumberType begin() const {
220         return my_begin;
221     }
222     NumberType end() const {
223         return my_end;
224     }
225     SieveRange(NumberType begin, NumberType end, NumberType stride, NumberType grainsize)
226             : my_begin(begin),
227               my_end(end),
228               my_stride(stride),
229               my_grainsize(grainsize < stride ? stride : grainsize) {
230         assert(assert_okay());
231     }
232 };
233 
234 //! Loop body for parallel_reduce.
235 /** parallel_reduce splits the sieve into subsieves.
236     Each subsieve handles a subrange of [0..n]. */
237 class Sieve {
238 public:
239     //! Prime Multiples to consider, and working storage for this subsieve.
240     ::Multiples multiples;
241 
242     //! NumberType of primes found so far by this subsieve.
243     NumberType count;
244 
245     //! Construct Sieve for counting primes in [0..n].
246     Sieve(NumberType n) : multiples(n), count(0) {}
247 
248     //------------------------------------------------------------------------
249     // Begin signatures required by parallel_reduce
250     //------------------------------------------------------------------------
251     void operator()(const SieveRange& r) {
252         NumberType m = multiples.m;
253         if (multiples.is_initialized()) {
254             // Simply reuse "Multiples" structure from previous window
255             // This works because parallel_reduce always applies
256             // *this from left to right.
257         }
258         else {
259             // Need to initialize "Multiples" because *this is a forked copy
260             // that needs to be set up to start at r.begin().
261             multiples.initialize(r.begin());
262         }
263         NumberType window_size = m;
264         for (NumberType j = r.begin(); j < r.end(); j += window_size) {
265             assert(j % multiples.m == 0);
266             if (j + window_size > r.end())
267                 window_size = r.end() - j;
268             count += multiples.find_primes_in_window(j, window_size);
269         }
270     }
271     void join(Sieve& other) {
272         count += other.count;
273         // Final value of multiples needs to final value of other multiples,
274         // so that *this can correctly process next window to right.
275         multiples.move(other.multiples);
276     }
277     Sieve(Sieve& other, oneapi::tbb::split)
278             : multiples(other.multiples, oneapi::tbb::split()),
279               count(0) {}
280     //------------------------------------------------------------------------
281     // End of signatures required by parallel_reduce
282     //------------------------------------------------------------------------
283 };
284 
285 //! Count number of primes between 0 and n
286 /** This is the parallel version. */
287 NumberType ParallelCountPrimes(NumberType n, int number_of_threads, NumberType grain_size) {
288     oneapi::tbb::global_control c(oneapi::tbb::global_control::max_allowed_parallelism,
289                                   number_of_threads);
290 
291     // Two is special case
292     NumberType count = n >= 2;
293     if (n >= 3) {
294         Sieve s(n);
295         count += s.multiples.n_factor;
296         if (printPrimes)
297             printf("---\n");
298         // Explicit grain size and simple_partitioner() used here instead of automatic grainsize
299         // determination because we want SieveRange to be decomposed down to grainSize or smaller.
300         // Doing so improves odds that the working set fits in cache when evaluating Sieve::operator().
301         oneapi::tbb::parallel_reduce(SieveRange(s.multiples.m, n, s.multiples.m, grain_size),
302                                      s,
303                                      oneapi::tbb::simple_partitioner());
304         count += s.count;
305     }
306     return count;
307 }
308