xref: /oneTBB/include/oneapi/tbb/parallel_sort.h (revision 4eec89fe)
149e08aacStbbdev /*
2b15aabb3Stbbdev     Copyright (c) 2005-2021 Intel Corporation
349e08aacStbbdev 
449e08aacStbbdev     Licensed under the Apache License, Version 2.0 (the "License");
549e08aacStbbdev     you may not use this file except in compliance with the License.
649e08aacStbbdev     You may obtain a copy of the License at
749e08aacStbbdev 
849e08aacStbbdev         http://www.apache.org/licenses/LICENSE-2.0
949e08aacStbbdev 
1049e08aacStbbdev     Unless required by applicable law or agreed to in writing, software
1149e08aacStbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1249e08aacStbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1349e08aacStbbdev     See the License for the specific language governing permissions and
1449e08aacStbbdev     limitations under the License.
1549e08aacStbbdev */
1649e08aacStbbdev 
1749e08aacStbbdev #ifndef __TBB_parallel_sort_H
1849e08aacStbbdev #define __TBB_parallel_sort_H
1949e08aacStbbdev 
2049e08aacStbbdev #include "detail/_namespace_injection.h"
2149e08aacStbbdev #include "parallel_for.h"
2249e08aacStbbdev #include "blocked_range.h"
2349e08aacStbbdev #include "profiling.h"
2449e08aacStbbdev 
2549e08aacStbbdev #include <algorithm>
2649e08aacStbbdev #include <iterator>
2749e08aacStbbdev #include <functional>
2849e08aacStbbdev #include <cstddef>
2949e08aacStbbdev 
3049e08aacStbbdev namespace tbb {
3149e08aacStbbdev namespace detail {
#if __TBB_CPP20_CONCEPTS_PRESENT
inline namespace d0 {

// TODO: consider using std::strict_weak_order concept
//! Syntactic requirement on the comparator accepted by parallel_sort.
/** Satisfied when a const-qualified Compare object can be called with two arguments of the
    iterator's reference type (std::iterator_traits<Iterator>::reference) and the call result
    is convertible to bool. Only the call syntax is checked here; whether the comparator
    really induces a strict weak ordering cannot be verified by a concept. */
template <typename Compare, typename Iterator>
concept compare = requires( const std::remove_reference_t<Compare>& comp, typename std::iterator_traits<Iterator>::reference value ) {
    // Forward via iterator_traits::reference.
    // A named parameter is always an lvalue; the functional casts turn `value` back into an
    // expression of the reference type itself (e.g. an xvalue for rvalue-reference iterators).
    { comp(typename std::iterator_traits<Iterator>::reference(value),
           typename std::iterator_traits<Iterator>::reference(value)) } -> std::convertible_to<bool>;
};

// Inspired by std::__PartiallyOrderedWith exposition only concept
//! Satisfied when two const lvalues of (reference-stripped) T can be compared with operator<.
/** The result of `lhs < rhs` must satisfy boolean_testable, which is declared outside this
    header (presumably modeled on the standard's exposition-only boolean-testable -- TODO confirm). */
template <typename T>
concept less_than_comparable = requires( const std::remove_reference_t<T>& lhs,
                                         const std::remove_reference_t<T>& rhs ) {
    { lhs < rhs } -> boolean_testable;
};

} // namespace d0
#endif // __TBB_CPP20_CONCEPTS_PRESENT
5249e08aacStbbdev namespace d1 {
5349e08aacStbbdev 
5449e08aacStbbdev //! Range used in quicksort to split elements into subranges based on a value.
5549e08aacStbbdev /** The split operation selects a splitter and places all elements less than or equal
5649e08aacStbbdev     to the value in the first range and the remaining elements in the second range.
5749e08aacStbbdev     @ingroup algorithms */
5849e08aacStbbdev template<typename RandomAccessIterator, typename Compare>
5949e08aacStbbdev class quick_sort_range {
median_of_three(const RandomAccessIterator & array,std::size_t l,std::size_t m,std::size_t r)6049e08aacStbbdev     std::size_t median_of_three( const RandomAccessIterator& array, std::size_t l, std::size_t m, std::size_t r ) const {
6149e08aacStbbdev         return comp(array[l], array[m]) ? ( comp(array[m], array[r]) ? m : ( comp(array[l], array[r]) ? r : l ) )
6249e08aacStbbdev                                         : ( comp(array[r], array[m]) ? m : ( comp(array[r], array[l]) ? r : l ) );
6349e08aacStbbdev     }
6449e08aacStbbdev 
pseudo_median_of_nine(const RandomAccessIterator & array,const quick_sort_range & range)6549e08aacStbbdev     std::size_t pseudo_median_of_nine( const RandomAccessIterator& array, const quick_sort_range& range ) const {
6649e08aacStbbdev         std::size_t offset = range.size / 8u;
6749e08aacStbbdev         return median_of_three(array,
6849e08aacStbbdev                                median_of_three(array, 0 , offset, offset * 2),
6949e08aacStbbdev                                median_of_three(array, offset * 3, offset * 4, offset * 5),
7049e08aacStbbdev                                median_of_three(array, offset * 6, offset * 7, range.size - 1));
7149e08aacStbbdev 
7249e08aacStbbdev     }
7349e08aacStbbdev 
split_range(quick_sort_range & range)7449e08aacStbbdev     std::size_t split_range( quick_sort_range& range ) {
7549e08aacStbbdev         RandomAccessIterator array = range.begin;
7649e08aacStbbdev         RandomAccessIterator first_element = range.begin;
7749e08aacStbbdev         std::size_t m = pseudo_median_of_nine(array, range);
7849e08aacStbbdev         if( m != 0 ) std::iter_swap(array, array + m);
7949e08aacStbbdev 
8049e08aacStbbdev         std::size_t i = 0;
8149e08aacStbbdev         std::size_t j = range.size;
8249e08aacStbbdev         // Partition interval [i + 1,j - 1] with key *first_element.
8349e08aacStbbdev         for(;;) {
8449e08aacStbbdev             __TBB_ASSERT( i < j, nullptr );
8549e08aacStbbdev             // Loop must terminate since array[l] == *first_element.
8649e08aacStbbdev             do {
8749e08aacStbbdev                 --j;
8849e08aacStbbdev                 __TBB_ASSERT( i <= j, "bad ordering relation?" );
8949e08aacStbbdev             } while( comp(*first_element, array[j]) );
9049e08aacStbbdev             do {
9149e08aacStbbdev                 __TBB_ASSERT( i <= j, nullptr );
9249e08aacStbbdev                 if( i == j ) goto partition;
9349e08aacStbbdev                 ++i;
9449e08aacStbbdev             } while( comp(array[i], *first_element) );
9549e08aacStbbdev             if( i == j ) goto partition;
9649e08aacStbbdev             std::iter_swap(array + i, array + j);
9749e08aacStbbdev         }
9849e08aacStbbdev partition:
9949e08aacStbbdev         // Put the partition key were it belongs
10049e08aacStbbdev         std::iter_swap(array + j, first_element);
10149e08aacStbbdev         // array[l..j) is less or equal to key.
10249e08aacStbbdev         // array(j..r) is greater or equal to key.
10349e08aacStbbdev         // array[j] is equal to key
10449e08aacStbbdev         i = j + 1;
10549e08aacStbbdev         std::size_t new_range_size = range.size - i;
10649e08aacStbbdev         range.size = j;
10749e08aacStbbdev         return new_range_size;
10849e08aacStbbdev     }
10949e08aacStbbdev 
11049e08aacStbbdev public:
11149e08aacStbbdev     quick_sort_range() = default;
11249e08aacStbbdev     quick_sort_range( const quick_sort_range& ) = default;
11349e08aacStbbdev     void operator=( const quick_sort_range& ) = delete;
11449e08aacStbbdev 
11549e08aacStbbdev     static constexpr std::size_t grainsize = 500;
11649e08aacStbbdev     const Compare& comp;
11749e08aacStbbdev     std::size_t size;
11849e08aacStbbdev     RandomAccessIterator begin;
11949e08aacStbbdev 
quick_sort_range(RandomAccessIterator begin_,std::size_t size_,const Compare & comp_)12049e08aacStbbdev     quick_sort_range( RandomAccessIterator begin_, std::size_t size_, const Compare& comp_ ) :
12149e08aacStbbdev         comp(comp_), size(size_), begin(begin_) {}
12249e08aacStbbdev 
empty()12349e08aacStbbdev     bool empty() const { return size == 0; }
is_divisible()12449e08aacStbbdev     bool is_divisible() const { return size >= grainsize; }
12549e08aacStbbdev 
quick_sort_range(quick_sort_range & range,split)12649e08aacStbbdev     quick_sort_range( quick_sort_range& range, split )
12749e08aacStbbdev         : comp(range.comp)
12849e08aacStbbdev         , size(split_range(range))
12949e08aacStbbdev           // +1 accounts for the pivot element, which is at its correct place
13049e08aacStbbdev           // already and, therefore, is not included into subranges.
13149e08aacStbbdev         , begin(range.begin + range.size + 1) {}
13249e08aacStbbdev };
13349e08aacStbbdev 
13449e08aacStbbdev //! Body class used to test if elements in a range are presorted
13549e08aacStbbdev /** @ingroup algorithms */
13649e08aacStbbdev template<typename RandomAccessIterator, typename Compare>
13749e08aacStbbdev class quick_sort_pretest_body {
13849e08aacStbbdev     const Compare& comp;
13949e08aacStbbdev     task_group_context& context;
14049e08aacStbbdev 
14149e08aacStbbdev public:
14249e08aacStbbdev     quick_sort_pretest_body() = default;
14349e08aacStbbdev     quick_sort_pretest_body( const quick_sort_pretest_body& ) = default;
14449e08aacStbbdev     void operator=( const quick_sort_pretest_body& ) = delete;
14549e08aacStbbdev 
quick_sort_pretest_body(const Compare & _comp,task_group_context & _context)14649e08aacStbbdev     quick_sort_pretest_body( const Compare& _comp, task_group_context& _context ) : comp(_comp), context(_context) {}
14749e08aacStbbdev 
operator()14849e08aacStbbdev     void operator()( const blocked_range<RandomAccessIterator>& range ) const {
14949e08aacStbbdev         RandomAccessIterator my_end = range.end();
15049e08aacStbbdev 
15149e08aacStbbdev         int i = 0;
15249e08aacStbbdev         //TODO: consider using std::is_sorted() for each 64 iterations (requires performance measurements)
15349e08aacStbbdev         for( RandomAccessIterator k = range.begin(); k != my_end; ++k, ++i ) {
15449e08aacStbbdev             if( i % 64 == 0 && context.is_group_execution_cancelled() ) break;
15549e08aacStbbdev 
15649e08aacStbbdev             // The k - 1 is never out-of-range because the first chunk starts at begin+serial_cutoff+1
15749e08aacStbbdev             if( comp(*(k), *(k - 1)) ) {
15849e08aacStbbdev                 context.cancel_group_execution();
15949e08aacStbbdev                 break;
16049e08aacStbbdev             }
16149e08aacStbbdev         }
16249e08aacStbbdev     }
16349e08aacStbbdev };
16449e08aacStbbdev 
16549e08aacStbbdev //! Body class used to sort elements in a range that is smaller than the grainsize.
16649e08aacStbbdev /** @ingroup algorithms */
16749e08aacStbbdev template<typename RandomAccessIterator, typename Compare>
16849e08aacStbbdev struct quick_sort_body {
operatorquick_sort_body16949e08aacStbbdev     void operator()( const quick_sort_range<RandomAccessIterator,Compare>& range ) const {
17049e08aacStbbdev         std::sort(range.begin, range.begin + range.size, range.comp);
17149e08aacStbbdev     }
17249e08aacStbbdev };
17349e08aacStbbdev 
17449e08aacStbbdev //! Method to perform parallel_for based quick sort.
17549e08aacStbbdev /** @ingroup algorithms */
17649e08aacStbbdev template<typename RandomAccessIterator, typename Compare>
do_parallel_quick_sort(RandomAccessIterator begin,RandomAccessIterator end,const Compare & comp)17749e08aacStbbdev void do_parallel_quick_sort( RandomAccessIterator begin, RandomAccessIterator end, const Compare& comp ) {
17849e08aacStbbdev     parallel_for(quick_sort_range<RandomAccessIterator,Compare>(begin, end - begin, comp),
17949e08aacStbbdev                  quick_sort_body<RandomAccessIterator,Compare>(),
18049e08aacStbbdev                  auto_partitioner());
18149e08aacStbbdev }
18249e08aacStbbdev 
18349e08aacStbbdev //! Wrapper method to initiate the sort by calling parallel_for.
18449e08aacStbbdev /** @ingroup algorithms */
18549e08aacStbbdev template<typename RandomAccessIterator, typename Compare>
parallel_quick_sort(RandomAccessIterator begin,RandomAccessIterator end,const Compare & comp)18649e08aacStbbdev void parallel_quick_sort( RandomAccessIterator begin, RandomAccessIterator end, const Compare& comp ) {
18749e08aacStbbdev     task_group_context my_context(PARALLEL_SORT);
18849e08aacStbbdev     constexpr int serial_cutoff = 9;
18949e08aacStbbdev 
19049e08aacStbbdev     __TBB_ASSERT( begin + serial_cutoff < end, "min_parallel_size is smaller than serial cutoff?" );
19149e08aacStbbdev     RandomAccessIterator k = begin;
19249e08aacStbbdev     for( ; k != begin + serial_cutoff; ++k ) {
19349e08aacStbbdev         if( comp(*(k + 1), *k) ) {
19449e08aacStbbdev             do_parallel_quick_sort(begin, end, comp);
1959052aaabSIvan Kochin             return;
19649e08aacStbbdev         }
19749e08aacStbbdev     }
19849e08aacStbbdev 
19949e08aacStbbdev     // Check is input range already sorted
20049e08aacStbbdev     parallel_for(blocked_range<RandomAccessIterator>(k + 1, end),
20149e08aacStbbdev                  quick_sort_pretest_body<RandomAccessIterator, Compare>(comp, my_context),
20249e08aacStbbdev                  auto_partitioner(),
20349e08aacStbbdev                  my_context);
20449e08aacStbbdev 
20549e08aacStbbdev     if( my_context.is_group_execution_cancelled() )
20649e08aacStbbdev         do_parallel_quick_sort(begin, end, comp);
20749e08aacStbbdev }
20849e08aacStbbdev 
/** \page parallel_sort_iter_req Requirements on iterators for parallel_sort
    Requirements on the iterator type \c It and its value type \c T for \c parallel_sort:

    - \code void iter_swap( It a, It b ) \endcode Swaps the values of the elements the given
    iterators \c a and \c b are pointing to. \c It should be a random access iterator.

    - \code bool Compare::operator()( const T& x, const T& y ) \endcode True if x comes before y.
    The comparison must induce a strict weak ordering on the values of \c T.

    - \c T should be movable (when C++20 concepts are available this is checked with \c std::movable).
**/
21749e08aacStbbdev 
21849e08aacStbbdev /** \name parallel_sort
21949e08aacStbbdev     See also requirements on \ref parallel_sort_iter_req "iterators for parallel_sort". **/
22049e08aacStbbdev //@{
22149e08aacStbbdev 
#if __TBB_CPP20_CONCEPTS_PRESENT
// Helper aliases used only inside the __TBB_requires clauses below, hence concept-only.

//! Value type of an iterator, as reported by std::iterator_traits.
template<typename It>
using iter_value_type = typename std::iterator_traits<It>::value_type;

//! Value type of the elements of a container-based range (via its iterator type).
template<typename Range>
using range_value_type = iter_value_type<range_iterator_type<Range>>;
#endif
229*4eec89feSIvan Kochin 
23049e08aacStbbdev //! Sorts the data in [begin,end) using the given comparator
23149e08aacStbbdev /** The compare function object is used for all comparisons between elements during sorting.
23249e08aacStbbdev     The compare object must define a bool operator() function.
23349e08aacStbbdev     @ingroup algorithms **/
23449e08aacStbbdev template<typename RandomAccessIterator, typename Compare>
__TBB_requires(std::random_access_iterator<RandomAccessIterator> && compare<Compare,RandomAccessIterator> && std::movable<iter_value_type<RandomAccessIterator>>)235478de5b1Stbbdev     __TBB_requires(std::random_access_iterator<RandomAccessIterator> &&
236*4eec89feSIvan Kochin                    compare<Compare, RandomAccessIterator> &&
237*4eec89feSIvan Kochin                    std::movable<iter_value_type<RandomAccessIterator>>)
23849e08aacStbbdev void parallel_sort( RandomAccessIterator begin, RandomAccessIterator end, const Compare& comp ) {
23949e08aacStbbdev     constexpr int min_parallel_size = 500;
24049e08aacStbbdev     if( end > begin ) {
24149e08aacStbbdev         if( end - begin < min_parallel_size ) {
24249e08aacStbbdev             std::sort(begin, end, comp);
24349e08aacStbbdev         } else {
24449e08aacStbbdev             parallel_quick_sort(begin, end, comp);
24549e08aacStbbdev         }
24649e08aacStbbdev     }
24749e08aacStbbdev }
24849e08aacStbbdev 
249b95bbc9cSIvan Kochin //! Sorts the data in [begin,end) with a default comparator \c std::less
25049e08aacStbbdev /** @ingroup algorithms **/
25149e08aacStbbdev template<typename RandomAccessIterator>
__TBB_requires(std::random_access_iterator<RandomAccessIterator> && less_than_comparable<iter_value_type<RandomAccessIterator>> && std::movable<iter_value_type<RandomAccessIterator>>)252478de5b1Stbbdev     __TBB_requires(std::random_access_iterator<RandomAccessIterator> &&
253*4eec89feSIvan Kochin                    less_than_comparable<iter_value_type<RandomAccessIterator>> &&
254*4eec89feSIvan Kochin                    std::movable<iter_value_type<RandomAccessIterator>>)
25549e08aacStbbdev void parallel_sort( RandomAccessIterator begin, RandomAccessIterator end ) {
25649e08aacStbbdev     parallel_sort(begin, end, std::less<typename std::iterator_traits<RandomAccessIterator>::value_type>());
25749e08aacStbbdev }
25849e08aacStbbdev 
25949e08aacStbbdev //! Sorts the data in rng using the given comparator
26049e08aacStbbdev /** @ingroup algorithms **/
26149e08aacStbbdev template<typename Range, typename Compare>
__TBB_requires(container_based_sequence<Range,std::random_access_iterator_tag> && compare<Compare,range_iterator_type<Range>> && std::movable<range_value_type<Range>>)262478de5b1Stbbdev     __TBB_requires(container_based_sequence<Range, std::random_access_iterator_tag> &&
263*4eec89feSIvan Kochin                    compare<Compare, range_iterator_type<Range>> &&
264*4eec89feSIvan Kochin                    std::movable<range_value_type<Range>>)
265b95bbc9cSIvan Kochin void parallel_sort( Range&& rng, const Compare& comp ) {
26649e08aacStbbdev     parallel_sort(std::begin(rng), std::end(rng), comp);
26749e08aacStbbdev }
26849e08aacStbbdev 
269b95bbc9cSIvan Kochin //! Sorts the data in rng with a default comparator \c std::less
27049e08aacStbbdev /** @ingroup algorithms **/
27149e08aacStbbdev template<typename Range>
__TBB_requires(container_based_sequence<Range,std::random_access_iterator_tag> && less_than_comparable<range_value_type<Range>> && std::movable<range_value_type<Range>>)272478de5b1Stbbdev     __TBB_requires(container_based_sequence<Range, std::random_access_iterator_tag> &&
273*4eec89feSIvan Kochin                    less_than_comparable<range_value_type<Range>> &&
274*4eec89feSIvan Kochin                    std::movable<range_value_type<Range>>)
275b95bbc9cSIvan Kochin void parallel_sort( Range&& rng ) {
27649e08aacStbbdev     parallel_sort(std::begin(rng), std::end(rng));
27749e08aacStbbdev }
27849e08aacStbbdev //@}
27949e08aacStbbdev 
28049e08aacStbbdev } // namespace d1
28149e08aacStbbdev } // namespace detail
28249e08aacStbbdev 
// Re-export all parallel_sort overloads from detail::d1 into the public API; since v1 is an
// inline namespace inside tbb, they are reachable as tbb::parallel_sort.
inline namespace v1 {
    using detail::d1::parallel_sort;
} // namespace v1
28649e08aacStbbdev } // namespace tbb
28749e08aacStbbdev 
28849e08aacStbbdev #endif /*__TBB_parallel_sort_H*/
289