xref: /oneTBB/include/oneapi/tbb/parallel_scan.h (revision a088cfa0)
149e08aacStbbdev /*
2*a088cfa0SKonstantin Boyarinov     Copyright (c) 2005-2023 Intel Corporation
349e08aacStbbdev 
449e08aacStbbdev     Licensed under the Apache License, Version 2.0 (the "License");
549e08aacStbbdev     you may not use this file except in compliance with the License.
649e08aacStbbdev     You may obtain a copy of the License at
749e08aacStbbdev 
849e08aacStbbdev         http://www.apache.org/licenses/LICENSE-2.0
949e08aacStbbdev 
1049e08aacStbbdev     Unless required by applicable law or agreed to in writing, software
1149e08aacStbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1249e08aacStbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1349e08aacStbbdev     See the License for the specific language governing permissions and
1449e08aacStbbdev     limitations under the License.
1549e08aacStbbdev */
1649e08aacStbbdev 
1749e08aacStbbdev #ifndef __TBB_parallel_scan_H
1849e08aacStbbdev #define __TBB_parallel_scan_H
1949e08aacStbbdev 
2049e08aacStbbdev #include <functional>
2149e08aacStbbdev 
2249e08aacStbbdev #include "detail/_config.h"
2349e08aacStbbdev #include "detail/_namespace_injection.h"
2449e08aacStbbdev #include "detail/_exception.h"
2549e08aacStbbdev #include "detail/_task.h"
2649e08aacStbbdev 
2749e08aacStbbdev #include "profiling.h"
2849e08aacStbbdev #include "partitioner.h"
2949e08aacStbbdev #include "blocked_range.h"
3049e08aacStbbdev #include "task_group.h"
3149e08aacStbbdev 
3249e08aacStbbdev namespace tbb {
3349e08aacStbbdev namespace detail {
3449e08aacStbbdev namespace d1 {
3549e08aacStbbdev 
//! Used to indicate that the initial scan is being performed.
/** Tag type passed as the second argument of a scan body during the preliminary pass.
    It converts implicitly to \c false, so it can also be handed to a callable that
    expects a plain \c bool "is final scan" flag.
    @ingroup algorithms */
struct pre_scan_tag {
    //! Always \c false: this tag never denotes the final pass.
    static constexpr bool is_final_scan() {return false;}
    //! Implicit conversion to is_final_scan(); const-qualified so that const tags convert as well.
    constexpr operator bool() const {return is_final_scan();}
};
4249e08aacStbbdev 
//! Used to indicate that the final scan is being performed.
/** Tag type passed as the second argument of a scan body during the final pass.
    It converts implicitly to \c true, so it can also be handed to a callable that
    expects a plain \c bool "is final scan" flag.
    @ingroup algorithms */
struct final_scan_tag {
    //! Always \c true: this tag denotes the final pass.
    static constexpr bool is_final_scan() {return true;}
    //! Implicit conversion to is_final_scan(); const-qualified so that const tags convert as well.
    constexpr operator bool() const {return is_final_scan();}
};
4949e08aacStbbdev 
5049e08aacStbbdev template<typename Range, typename Body>
5149e08aacStbbdev struct sum_node;
5249e08aacStbbdev 
53478de5b1Stbbdev #if __TBB_CPP20_CONCEPTS_PRESENT
54478de5b1Stbbdev } // namespace d1
55478de5b1Stbbdev namespace d0 {
56478de5b1Stbbdev 
57478de5b1Stbbdev template <typename Body, typename Range>
58478de5b1Stbbdev concept parallel_scan_body = splittable<Body> &&
requires(Body & body,const Range & range,Body & other)59478de5b1Stbbdev                              requires( Body& body, const Range& range, Body& other ) {
60478de5b1Stbbdev                                  body(range, tbb::detail::d1::pre_scan_tag{});
61478de5b1Stbbdev                                  body(range, tbb::detail::d1::final_scan_tag{});
62478de5b1Stbbdev                                  body.reverse_join(other);
63478de5b1Stbbdev                                  body.assign(other);
64478de5b1Stbbdev                              };
65478de5b1Stbbdev 
66478de5b1Stbbdev template <typename Function, typename Range, typename Value>
67*a088cfa0SKonstantin Boyarinov concept parallel_scan_function = std::invocable<const std::remove_reference_t<Function>&,
68*a088cfa0SKonstantin Boyarinov                                                 const Range&, const Value&, bool> &&
69*a088cfa0SKonstantin Boyarinov                                  std::convertible_to<std::invoke_result_t<const std::remove_reference_t<Function>&,
70*a088cfa0SKonstantin Boyarinov                                                                           const Range&, const Value&, bool>,
71*a088cfa0SKonstantin Boyarinov                                                      Value>;
72478de5b1Stbbdev 
73478de5b1Stbbdev template <typename Combine, typename Value>
74*a088cfa0SKonstantin Boyarinov concept parallel_scan_combine = std::invocable<const std::remove_reference_t<Combine>&,
75*a088cfa0SKonstantin Boyarinov                                                const Value&, const Value&> &&
76*a088cfa0SKonstantin Boyarinov                                 std::convertible_to<std::invoke_result_t<const std::remove_reference_t<Combine>&,
77*a088cfa0SKonstantin Boyarinov                                                                          const Value&, const Value&>,
78*a088cfa0SKonstantin Boyarinov                                                     Value>;
79478de5b1Stbbdev 
80478de5b1Stbbdev } // namespace d0
81478de5b1Stbbdev namespace d1 {
82478de5b1Stbbdev #endif // __TBB_CPP20_CONCEPTS_PRESENT
83478de5b1Stbbdev 
8449e08aacStbbdev //! Performs final scan for a leaf
8549e08aacStbbdev /** @ingroup algorithms */
8649e08aacStbbdev template<typename Range, typename Body>
8749e08aacStbbdev struct final_sum : public task {
8849e08aacStbbdev private:
8949e08aacStbbdev     using sum_node_type = sum_node<Range, Body>;
9049e08aacStbbdev     Body m_body;
9149e08aacStbbdev     aligned_space<Range> m_range;
9249e08aacStbbdev     //! Where to put result of last subrange, or nullptr if not last subrange.
9349e08aacStbbdev     Body* m_stuff_last;
9449e08aacStbbdev 
9549e08aacStbbdev     wait_context& m_wait_context;
9649e08aacStbbdev     sum_node_type* m_parent = nullptr;
9749e08aacStbbdev public:
9849e08aacStbbdev     small_object_allocator m_allocator;
final_sumfinal_sum9949e08aacStbbdev     final_sum( Body& body, wait_context& w_o, small_object_allocator& alloc ) :
10049e08aacStbbdev         m_body(body, split()), m_wait_context(w_o), m_allocator(alloc) {
10149e08aacStbbdev         poison_pointer(m_stuff_last);
10249e08aacStbbdev     }
10349e08aacStbbdev 
final_sumfinal_sum10449e08aacStbbdev     final_sum( final_sum& sum, small_object_allocator& alloc ) :
10549e08aacStbbdev         m_body(sum.m_body, split()), m_wait_context(sum.m_wait_context), m_allocator(alloc) {
10649e08aacStbbdev         poison_pointer(m_stuff_last);
10749e08aacStbbdev     }
10849e08aacStbbdev 
~final_sumfinal_sum10949e08aacStbbdev     ~final_sum() {
11049e08aacStbbdev         m_range.begin()->~Range();
11149e08aacStbbdev     }
finish_constructionfinal_sum11249e08aacStbbdev     void finish_construction( sum_node_type* parent, const Range& range, Body* stuff_last ) {
11349e08aacStbbdev         __TBB_ASSERT( m_parent == nullptr, nullptr );
11449e08aacStbbdev         m_parent = parent;
11549e08aacStbbdev         new( m_range.begin() ) Range(range);
11649e08aacStbbdev         m_stuff_last = stuff_last;
11749e08aacStbbdev     }
11849e08aacStbbdev private:
release_parentfinal_sum11949e08aacStbbdev     sum_node_type* release_parent() {
12049e08aacStbbdev         call_itt_task_notify(releasing, m_parent);
12149e08aacStbbdev         if (m_parent) {
12249e08aacStbbdev             auto parent = m_parent;
12349e08aacStbbdev             m_parent = nullptr;
124478de5b1Stbbdev             if (parent->ref_count.fetch_sub(1) == 1) {
12549e08aacStbbdev                 return parent;
12649e08aacStbbdev             }
12749e08aacStbbdev         }
12849e08aacStbbdev         else
12949e08aacStbbdev             m_wait_context.release();
13049e08aacStbbdev         return nullptr;
13149e08aacStbbdev     }
finalizefinal_sum13249e08aacStbbdev     sum_node_type* finalize(const execution_data& ed){
13349e08aacStbbdev         sum_node_type* next_task = release_parent();
13449e08aacStbbdev         m_allocator.delete_object<final_sum>(this, ed);
13549e08aacStbbdev         return next_task;
13649e08aacStbbdev     }
13749e08aacStbbdev 
13849e08aacStbbdev public:
executefinal_sum13949e08aacStbbdev     task* execute(execution_data& ed) override {
14049e08aacStbbdev         m_body( *m_range.begin(), final_scan_tag() );
14149e08aacStbbdev         if( m_stuff_last )
14249e08aacStbbdev             m_stuff_last->assign(m_body);
14349e08aacStbbdev 
14449e08aacStbbdev         return finalize(ed);
14549e08aacStbbdev     }
cancelfinal_sum14649e08aacStbbdev     task* cancel(execution_data& ed) override {
14749e08aacStbbdev         return finalize(ed);
14849e08aacStbbdev     }
14949e08aacStbbdev     template<typename Tag>
operatorfinal_sum15049e08aacStbbdev     void operator()( const Range& r, Tag tag ) {
15149e08aacStbbdev         m_body( r, tag );
15249e08aacStbbdev     }
reverse_joinfinal_sum15349e08aacStbbdev     void reverse_join( final_sum& a ) {
15449e08aacStbbdev         m_body.reverse_join(a.m_body);
15549e08aacStbbdev     }
reverse_joinfinal_sum15649e08aacStbbdev     void reverse_join( Body& body ) {
15749e08aacStbbdev         m_body.reverse_join(body);
15849e08aacStbbdev     }
assign_tofinal_sum15949e08aacStbbdev     void assign_to( Body& body ) {
16049e08aacStbbdev         body.assign(m_body);
16149e08aacStbbdev     }
self_destroyfinal_sum16249e08aacStbbdev     void self_destroy(const execution_data& ed) {
16349e08aacStbbdev         m_allocator.delete_object<final_sum>(this, ed);
16449e08aacStbbdev     }
16549e08aacStbbdev };
16649e08aacStbbdev 
16749e08aacStbbdev //! Split work to be done in the scan.
16849e08aacStbbdev /** @ingroup algorithms */
16949e08aacStbbdev template<typename Range, typename Body>
17049e08aacStbbdev struct sum_node : public task {
17149e08aacStbbdev private:
17249e08aacStbbdev     using final_sum_type = final_sum<Range,Body>;
17349e08aacStbbdev public:
17449e08aacStbbdev     final_sum_type *m_incoming;
17549e08aacStbbdev     final_sum_type *m_body;
17649e08aacStbbdev     Body *m_stuff_last;
17749e08aacStbbdev private:
17849e08aacStbbdev     final_sum_type *m_left_sum;
17949e08aacStbbdev     sum_node *m_left;
18049e08aacStbbdev     sum_node *m_right;
18149e08aacStbbdev     bool m_left_is_final;
18249e08aacStbbdev     Range m_range;
18349e08aacStbbdev     wait_context& m_wait_context;
18449e08aacStbbdev     sum_node* m_parent;
18549e08aacStbbdev     small_object_allocator m_allocator;
18649e08aacStbbdev public:
18749e08aacStbbdev     std::atomic<unsigned int> ref_count{0};
sum_nodesum_node18849e08aacStbbdev     sum_node( const Range range, bool left_is_final_, sum_node* parent, wait_context& w_o, small_object_allocator& alloc ) :
18949e08aacStbbdev         m_stuff_last(nullptr),
19049e08aacStbbdev         m_left_sum(nullptr),
19149e08aacStbbdev         m_left(nullptr),
19249e08aacStbbdev         m_right(nullptr),
19349e08aacStbbdev         m_left_is_final(left_is_final_),
19449e08aacStbbdev         m_range(range),
19549e08aacStbbdev         m_wait_context(w_o),
19649e08aacStbbdev         m_parent(parent),
19749e08aacStbbdev         m_allocator(alloc)
19849e08aacStbbdev     {
19949e08aacStbbdev         if( m_parent )
200478de5b1Stbbdev             m_parent->ref_count.fetch_add(1);
20149e08aacStbbdev         // Poison fields that will be set by second pass.
20249e08aacStbbdev         poison_pointer(m_body);
20349e08aacStbbdev         poison_pointer(m_incoming);
20449e08aacStbbdev     }
20549e08aacStbbdev 
~sum_nodesum_node20649e08aacStbbdev     ~sum_node() {
20749e08aacStbbdev         if (m_parent)
208478de5b1Stbbdev             m_parent->ref_count.fetch_sub(1);
20949e08aacStbbdev     }
21049e08aacStbbdev private:
release_parentsum_node21149e08aacStbbdev     sum_node* release_parent() {
21249e08aacStbbdev         call_itt_task_notify(releasing, m_parent);
21349e08aacStbbdev         if (m_parent) {
21449e08aacStbbdev             auto parent = m_parent;
21549e08aacStbbdev             m_parent = nullptr;
216478de5b1Stbbdev             if (parent->ref_count.fetch_sub(1) == 1) {
21749e08aacStbbdev                 return parent;
21849e08aacStbbdev             }
21949e08aacStbbdev         }
22049e08aacStbbdev         else
22149e08aacStbbdev             m_wait_context.release();
22249e08aacStbbdev         return nullptr;
22349e08aacStbbdev     }
create_childsum_node22449e08aacStbbdev     task* create_child( const Range& range, final_sum_type& body, sum_node* child, final_sum_type* incoming, Body* stuff_last ) {
22549e08aacStbbdev         if( child ) {
22649e08aacStbbdev             __TBB_ASSERT( is_poisoned(child->m_body) && is_poisoned(child->m_incoming), nullptr );
22749e08aacStbbdev             child->prepare_for_execution(body, incoming, stuff_last);
22849e08aacStbbdev             return child;
22949e08aacStbbdev         } else {
23049e08aacStbbdev             body.finish_construction(this, range, stuff_last);
23149e08aacStbbdev             return &body;
23249e08aacStbbdev         }
23349e08aacStbbdev     }
23449e08aacStbbdev 
finalizesum_node23549e08aacStbbdev     sum_node* finalize(const execution_data& ed) {
23649e08aacStbbdev         sum_node* next_task = release_parent();
23749e08aacStbbdev         m_allocator.delete_object<sum_node>(this, ed);
23849e08aacStbbdev         return next_task;
23949e08aacStbbdev     }
24049e08aacStbbdev 
24149e08aacStbbdev public:
prepare_for_executionsum_node24249e08aacStbbdev     void prepare_for_execution(final_sum_type& body, final_sum_type* incoming, Body *stuff_last) {
24349e08aacStbbdev         this->m_body = &body;
24449e08aacStbbdev         this->m_incoming = incoming;
24549e08aacStbbdev         this->m_stuff_last = stuff_last;
24649e08aacStbbdev     }
executesum_node24749e08aacStbbdev     task* execute(execution_data& ed) override {
24849e08aacStbbdev         if( m_body ) {
24949e08aacStbbdev             if( m_incoming )
25049e08aacStbbdev                 m_left_sum->reverse_join( *m_incoming );
25149e08aacStbbdev             task* right_child = this->create_child(Range(m_range,split()), *m_left_sum, m_right, m_left_sum, m_stuff_last);
25249e08aacStbbdev             task* left_child = m_left_is_final ? nullptr : this->create_child(m_range, *m_body, m_left, m_incoming, nullptr);
25349e08aacStbbdev             ref_count = (left_child != nullptr) + (right_child != nullptr);
25449e08aacStbbdev             m_body = nullptr;
25549e08aacStbbdev             if( left_child ) {
25649e08aacStbbdev                 spawn(*right_child, *ed.context);
25749e08aacStbbdev                 return left_child;
25849e08aacStbbdev             } else {
25949e08aacStbbdev                 return right_child;
26049e08aacStbbdev             }
26149e08aacStbbdev         } else {
26249e08aacStbbdev             return finalize(ed);
26349e08aacStbbdev         }
26449e08aacStbbdev     }
cancelsum_node26549e08aacStbbdev     task* cancel(execution_data& ed) override {
26649e08aacStbbdev         return finalize(ed);
26749e08aacStbbdev     }
self_destroysum_node26849e08aacStbbdev     void self_destroy(const execution_data& ed) {
26949e08aacStbbdev         m_allocator.delete_object<sum_node>(this, ed);
27049e08aacStbbdev     }
27149e08aacStbbdev     template<typename range,typename body,typename partitioner>
27249e08aacStbbdev     friend struct start_scan;
27349e08aacStbbdev 
27449e08aacStbbdev     template<typename range,typename body>
27549e08aacStbbdev     friend struct finish_scan;
27649e08aacStbbdev };
27749e08aacStbbdev 
27849e08aacStbbdev //! Combine partial results
27949e08aacStbbdev /** @ingroup algorithms */
28049e08aacStbbdev template<typename Range, typename Body>
28149e08aacStbbdev struct finish_scan : public task {
28249e08aacStbbdev private:
28349e08aacStbbdev     using sum_node_type = sum_node<Range,Body>;
28449e08aacStbbdev     using final_sum_type = final_sum<Range,Body>;
28549e08aacStbbdev     final_sum_type** const m_sum_slot;
28649e08aacStbbdev     sum_node_type*& m_return_slot;
28749e08aacStbbdev     small_object_allocator m_allocator;
28849e08aacStbbdev public:
289478de5b1Stbbdev     std::atomic<final_sum_type*> m_right_zombie;
29049e08aacStbbdev     sum_node_type& m_result;
29149e08aacStbbdev     std::atomic<unsigned int> ref_count{2};
29249e08aacStbbdev     finish_scan*  m_parent;
29349e08aacStbbdev     wait_context& m_wait_context;
executefinish_scan29449e08aacStbbdev     task* execute(execution_data& ed) override {
29549e08aacStbbdev         __TBB_ASSERT( m_result.ref_count.load() == static_cast<unsigned int>((m_result.m_left!=nullptr)+(m_result.m_right!=nullptr)), nullptr );
29649e08aacStbbdev         if( m_result.m_left )
29749e08aacStbbdev             m_result.m_left_is_final = false;
298478de5b1Stbbdev         final_sum_type* right_zombie = m_right_zombie.load(std::memory_order_acquire);
299478de5b1Stbbdev         if( right_zombie && m_sum_slot )
30049e08aacStbbdev             (*m_sum_slot)->reverse_join(*m_result.m_left_sum);
30149e08aacStbbdev         __TBB_ASSERT( !m_return_slot, nullptr );
302478de5b1Stbbdev         if( right_zombie || m_result.m_right ) {
30349e08aacStbbdev             m_return_slot = &m_result;
30449e08aacStbbdev         } else {
30549e08aacStbbdev             m_result.self_destroy(ed);
30649e08aacStbbdev         }
307478de5b1Stbbdev         if( right_zombie && !m_sum_slot && !m_result.m_right ) {
308478de5b1Stbbdev             right_zombie->self_destroy(ed);
309478de5b1Stbbdev             m_right_zombie.store(nullptr, std::memory_order_relaxed);
31049e08aacStbbdev         }
31149e08aacStbbdev         return finalize(ed);
31249e08aacStbbdev     }
cancelfinish_scan31349e08aacStbbdev     task* cancel(execution_data& ed) override {
31449e08aacStbbdev         return finalize(ed);
31549e08aacStbbdev     }
finish_scanfinish_scan31649e08aacStbbdev     finish_scan(sum_node_type*& return_slot, final_sum_type** sum, sum_node_type& result_, finish_scan* parent, wait_context& w_o, small_object_allocator& alloc) :
31749e08aacStbbdev         m_sum_slot(sum),
31849e08aacStbbdev         m_return_slot(return_slot),
31949e08aacStbbdev         m_allocator(alloc),
32049e08aacStbbdev         m_right_zombie(nullptr),
32149e08aacStbbdev         m_result(result_),
32249e08aacStbbdev         m_parent(parent),
32349e08aacStbbdev         m_wait_context(w_o)
32449e08aacStbbdev     {
32549e08aacStbbdev         __TBB_ASSERT( !m_return_slot, nullptr );
32649e08aacStbbdev     }
32749e08aacStbbdev private:
release_parentfinish_scan32849e08aacStbbdev     finish_scan* release_parent() {
32949e08aacStbbdev         call_itt_task_notify(releasing, m_parent);
33049e08aacStbbdev         if (m_parent) {
33149e08aacStbbdev             auto parent = m_parent;
33249e08aacStbbdev             m_parent = nullptr;
333478de5b1Stbbdev             if (parent->ref_count.fetch_sub(1) == 1) {
33449e08aacStbbdev                 return parent;
33549e08aacStbbdev             }
33649e08aacStbbdev         }
33749e08aacStbbdev         else
33849e08aacStbbdev             m_wait_context.release();
33949e08aacStbbdev         return nullptr;
34049e08aacStbbdev     }
finalizefinish_scan34149e08aacStbbdev     finish_scan* finalize(const execution_data& ed) {
34249e08aacStbbdev         finish_scan* next_task = release_parent();
34349e08aacStbbdev         m_allocator.delete_object<finish_scan>(this, ed);
34449e08aacStbbdev         return next_task;
34549e08aacStbbdev     }
34649e08aacStbbdev };
34749e08aacStbbdev 
34849e08aacStbbdev //! Initial task to split the work
34949e08aacStbbdev /** @ingroup algorithms */
35049e08aacStbbdev template<typename Range, typename Body, typename Partitioner>
35149e08aacStbbdev struct start_scan : public task {
35249e08aacStbbdev private:
35349e08aacStbbdev     using sum_node_type = sum_node<Range,Body>;
35449e08aacStbbdev     using final_sum_type = final_sum<Range,Body>;
35549e08aacStbbdev     using finish_pass1_type = finish_scan<Range,Body>;
35649e08aacStbbdev     std::reference_wrapper<sum_node_type*> m_return_slot;
35749e08aacStbbdev     Range m_range;
35849e08aacStbbdev     std::reference_wrapper<final_sum_type> m_body;
35949e08aacStbbdev     typename Partitioner::partition_type m_partition;
36049e08aacStbbdev     /** Non-null if caller is requesting total. */
36149e08aacStbbdev     final_sum_type** m_sum_slot;
36249e08aacStbbdev     bool m_is_final;
36349e08aacStbbdev     bool m_is_right_child;
36449e08aacStbbdev 
36549e08aacStbbdev     finish_pass1_type*  m_parent;
36649e08aacStbbdev     small_object_allocator m_allocator;
36749e08aacStbbdev     wait_context& m_wait_context;
36849e08aacStbbdev 
release_parentstart_scan36949e08aacStbbdev     finish_pass1_type* release_parent() {
37049e08aacStbbdev         call_itt_task_notify(releasing, m_parent);
37149e08aacStbbdev         if (m_parent) {
37249e08aacStbbdev             auto parent = m_parent;
37349e08aacStbbdev             m_parent = nullptr;
374478de5b1Stbbdev             if (parent->ref_count.fetch_sub(1) == 1) {
37549e08aacStbbdev                 return parent;
37649e08aacStbbdev             }
37749e08aacStbbdev         }
37849e08aacStbbdev         else
37949e08aacStbbdev             m_wait_context.release();
38049e08aacStbbdev         return nullptr;
38149e08aacStbbdev     }
38249e08aacStbbdev 
finalizestart_scan38349e08aacStbbdev     finish_pass1_type* finalize( const execution_data& ed ) {
38449e08aacStbbdev         finish_pass1_type* next_task = release_parent();
38549e08aacStbbdev         m_allocator.delete_object<start_scan>(this, ed);
38649e08aacStbbdev         return next_task;
38749e08aacStbbdev     }
38849e08aacStbbdev 
38949e08aacStbbdev public:
39049e08aacStbbdev     task* execute( execution_data& ) override;
cancelstart_scan39149e08aacStbbdev     task* cancel( execution_data& ed ) override {
39249e08aacStbbdev         return finalize(ed);
39349e08aacStbbdev     }
start_scanstart_scan39449e08aacStbbdev     start_scan( sum_node_type*& return_slot, start_scan& parent, small_object_allocator& alloc ) :
39549e08aacStbbdev         m_return_slot(return_slot),
39649e08aacStbbdev         m_range(parent.m_range,split()),
39749e08aacStbbdev         m_body(parent.m_body),
39849e08aacStbbdev         m_partition(parent.m_partition,split()),
39949e08aacStbbdev         m_sum_slot(parent.m_sum_slot),
40049e08aacStbbdev         m_is_final(parent.m_is_final),
40149e08aacStbbdev         m_is_right_child(true),
40249e08aacStbbdev         m_parent(parent.m_parent),
40349e08aacStbbdev         m_allocator(alloc),
40449e08aacStbbdev         m_wait_context(parent.m_wait_context)
40549e08aacStbbdev     {
40649e08aacStbbdev         __TBB_ASSERT( !m_return_slot, nullptr );
40749e08aacStbbdev         parent.m_is_right_child = false;
40849e08aacStbbdev     }
40949e08aacStbbdev 
start_scanstart_scan41049e08aacStbbdev     start_scan( sum_node_type*& return_slot, const Range& range, final_sum_type& body, const Partitioner& partitioner, wait_context& w_o, small_object_allocator& alloc ) :
41149e08aacStbbdev         m_return_slot(return_slot),
41249e08aacStbbdev         m_range(range),
41349e08aacStbbdev         m_body(body),
41449e08aacStbbdev         m_partition(partitioner),
41549e08aacStbbdev         m_sum_slot(nullptr),
41649e08aacStbbdev         m_is_final(true),
41749e08aacStbbdev         m_is_right_child(false),
41849e08aacStbbdev         m_parent(nullptr),
41949e08aacStbbdev         m_allocator(alloc),
42049e08aacStbbdev         m_wait_context(w_o)
42149e08aacStbbdev     {
42249e08aacStbbdev         __TBB_ASSERT( !m_return_slot, nullptr );
42349e08aacStbbdev     }
42449e08aacStbbdev 
runstart_scan42549e08aacStbbdev     static void run( const Range& range, Body& body, const Partitioner& partitioner ) {
42649e08aacStbbdev         if( !range.empty() ) {
42749e08aacStbbdev             task_group_context context(PARALLEL_SCAN);
42849e08aacStbbdev 
42949e08aacStbbdev             using start_pass1_type = start_scan<Range,Body,Partitioner>;
43049e08aacStbbdev             sum_node_type* root = nullptr;
43149e08aacStbbdev             wait_context w_ctx{1};
43249e08aacStbbdev             small_object_allocator alloc{};
43349e08aacStbbdev 
43449e08aacStbbdev             auto& temp_body = *alloc.new_object<final_sum_type>(body, w_ctx, alloc);
43549e08aacStbbdev             temp_body.reverse_join(body);
43649e08aacStbbdev 
43749e08aacStbbdev             auto& pass1 = *alloc.new_object<start_pass1_type>(/*m_return_slot=*/root, range, temp_body, partitioner, w_ctx, alloc);
43849e08aacStbbdev 
43949e08aacStbbdev             execute_and_wait(pass1, context, w_ctx, context);
44049e08aacStbbdev             if( root ) {
44149e08aacStbbdev                 root->prepare_for_execution(temp_body, nullptr, &body);
44249e08aacStbbdev                 w_ctx.reserve();
44349e08aacStbbdev                 execute_and_wait(*root, context, w_ctx, context);
44449e08aacStbbdev             } else {
44549e08aacStbbdev                 temp_body.assign_to(body);
44649e08aacStbbdev                 temp_body.finish_construction(nullptr, range, nullptr);
44749e08aacStbbdev                 alloc.delete_object<final_sum_type>(&temp_body);
44849e08aacStbbdev             }
44949e08aacStbbdev         }
45049e08aacStbbdev     }
45149e08aacStbbdev };
45249e08aacStbbdev 
45349e08aacStbbdev template<typename Range, typename Body, typename Partitioner>
execute(execution_data & ed)45449e08aacStbbdev task* start_scan<Range,Body,Partitioner>::execute( execution_data& ed ) {
45549e08aacStbbdev     // Inspecting m_parent->result.left_sum would ordinarily be a race condition.
45649e08aacStbbdev     // But we inspect it only if we are not a stolen task, in which case we
45749e08aacStbbdev     // know that task assigning to m_parent->result.left_sum has completed.
45849e08aacStbbdev     __TBB_ASSERT(!m_is_right_child || m_parent, "right child is never an orphan");
45949e08aacStbbdev     bool treat_as_stolen = m_is_right_child && (is_stolen(ed) || &m_body.get()!=m_parent->m_result.m_left_sum);
46049e08aacStbbdev     if( treat_as_stolen ) {
46149e08aacStbbdev         // Invocation is for right child that has been really stolen or needs to be virtually stolen
46249e08aacStbbdev         small_object_allocator alloc{};
463478de5b1Stbbdev         final_sum_type* right_zombie = alloc.new_object<final_sum_type>(m_body, alloc);
464478de5b1Stbbdev         m_parent->m_right_zombie.store(right_zombie, std::memory_order_release);
465478de5b1Stbbdev         m_body = *right_zombie;
46649e08aacStbbdev         m_is_final = false;
46749e08aacStbbdev     }
46849e08aacStbbdev     task* next_task = nullptr;
46949e08aacStbbdev     if( (m_is_right_child && !treat_as_stolen) || !m_range.is_divisible() || m_partition.should_execute_range(ed) ) {
47049e08aacStbbdev         if( m_is_final )
47149e08aacStbbdev             m_body(m_range, final_scan_tag());
47249e08aacStbbdev         else if( m_sum_slot )
47349e08aacStbbdev             m_body(m_range, pre_scan_tag());
47449e08aacStbbdev         if( m_sum_slot )
47549e08aacStbbdev             *m_sum_slot = &m_body.get();
47649e08aacStbbdev         __TBB_ASSERT( !m_return_slot, nullptr );
47749e08aacStbbdev 
47849e08aacStbbdev         next_task = finalize(ed);
47949e08aacStbbdev     } else {
48049e08aacStbbdev         small_object_allocator alloc{};
48149e08aacStbbdev         auto result = alloc.new_object<sum_node_type>(m_range,/*m_left_is_final=*/m_is_final, m_parent? &m_parent->m_result: nullptr, m_wait_context, alloc);
48249e08aacStbbdev 
48349e08aacStbbdev         auto new_parent = alloc.new_object<finish_pass1_type>(m_return_slot, m_sum_slot, *result, m_parent, m_wait_context, alloc);
48449e08aacStbbdev         m_parent = new_parent;
48549e08aacStbbdev 
48649e08aacStbbdev         // Split off right child
48749e08aacStbbdev         auto& right_child = *alloc.new_object<start_scan>(/*m_return_slot=*/result->m_right, *this, alloc);
48849e08aacStbbdev 
48949e08aacStbbdev         spawn(right_child, *ed.context);
49049e08aacStbbdev 
49149e08aacStbbdev         m_sum_slot = &result->m_left_sum;
49249e08aacStbbdev         m_return_slot = result->m_left;
49349e08aacStbbdev 
49449e08aacStbbdev         __TBB_ASSERT( !m_return_slot, nullptr );
49549e08aacStbbdev         next_task = this;
49649e08aacStbbdev     }
49749e08aacStbbdev     return next_task;
49849e08aacStbbdev }
49949e08aacStbbdev 
50049e08aacStbbdev template<typename Range, typename Value, typename Scan, typename ReverseJoin>
50149e08aacStbbdev class lambda_scan_body {
50249e08aacStbbdev     Value               m_sum_slot;
50349e08aacStbbdev     const Value&        identity_element;
50449e08aacStbbdev     const Scan&         m_scan;
50549e08aacStbbdev     const ReverseJoin&  m_reverse_join;
50649e08aacStbbdev public:
50749e08aacStbbdev     void operator=(const lambda_scan_body&) = delete;
50849e08aacStbbdev     lambda_scan_body(const lambda_scan_body&) = default;
50949e08aacStbbdev 
lambda_scan_body(const Value & identity,const Scan & scan,const ReverseJoin & rev_join)51049e08aacStbbdev     lambda_scan_body( const Value& identity, const Scan& scan, const ReverseJoin& rev_join )
51149e08aacStbbdev         : m_sum_slot(identity)
51249e08aacStbbdev         , identity_element(identity)
51349e08aacStbbdev         , m_scan(scan)
51449e08aacStbbdev         , m_reverse_join(rev_join) {}
51549e08aacStbbdev 
lambda_scan_body(lambda_scan_body & b,split)51649e08aacStbbdev     lambda_scan_body( lambda_scan_body& b, split )
51749e08aacStbbdev         : m_sum_slot(b.identity_element)
51849e08aacStbbdev         , identity_element(b.identity_element)
51949e08aacStbbdev         , m_scan(b.m_scan)
52049e08aacStbbdev         , m_reverse_join(b.m_reverse_join) {}
52149e08aacStbbdev 
52249e08aacStbbdev     template<typename Tag>
operator()52349e08aacStbbdev     void operator()( const Range& r, Tag tag ) {
524*a088cfa0SKonstantin Boyarinov         m_sum_slot = tbb::detail::invoke(m_scan, r, m_sum_slot, tag);
52549e08aacStbbdev     }
52649e08aacStbbdev 
reverse_join(lambda_scan_body & a)52749e08aacStbbdev     void reverse_join( lambda_scan_body& a ) {
528*a088cfa0SKonstantin Boyarinov         m_sum_slot = tbb::detail::invoke(m_reverse_join, a.m_sum_slot, m_sum_slot);
52949e08aacStbbdev     }
53049e08aacStbbdev 
assign(lambda_scan_body & b)53149e08aacStbbdev     void assign( lambda_scan_body& b ) {
53249e08aacStbbdev         m_sum_slot = b.m_sum_slot;
53349e08aacStbbdev     }
53449e08aacStbbdev 
result()53549e08aacStbbdev     Value result() const {
53649e08aacStbbdev         return m_sum_slot;
53749e08aacStbbdev     }
53849e08aacStbbdev };
53949e08aacStbbdev 
54049e08aacStbbdev // Requirements on Range concept are documented in blocked_range.h
54149e08aacStbbdev 
/** \page parallel_scan_body_req Requirements on parallel_scan body
    Class \c Body implementing the concept of parallel_scan body must define:
    - \code Body::Body( Body&, split ); \endcode    Splitting constructor.
                                                    Split \c b so that \c this and \c b can accumulate separately
    - \code Body::~Body(); \endcode                 Destructor
    - \code void Body::operator()( const Range& r, pre_scan_tag ); \endcode
                                                    Preprocess iterations for range \c r
    - \code void Body::operator()( const Range& r, final_scan_tag ); \endcode
                                                    Do final processing for iterations of range \c r
    - \code void Body::reverse_join( Body& a ); \endcode
                                                    Merge preprocessing state of \c a into \c this, where \c a was
                                                    created earlier from \c b by b's splitting constructor
    - \code void Body::assign( Body& b ); \endcode
                                                    Assign the state of \c b to \c this
**/
55549e08aacStbbdev 
55649e08aacStbbdev /** \name parallel_scan
55749e08aacStbbdev     See also requirements on \ref range_req "Range" and \ref parallel_scan_body_req "parallel_scan Body". **/
55849e08aacStbbdev //@{
55949e08aacStbbdev 
56049e08aacStbbdev //! Parallel prefix with default partitioner
56149e08aacStbbdev /** @ingroup algorithms **/
56249e08aacStbbdev template<typename Range, typename Body>
__TBB_requires(tbb_range<Range> && parallel_scan_body<Body,Range>)5634a23d002Skboyarinov     __TBB_requires(tbb_range<Range> && parallel_scan_body<Body, Range>)
56449e08aacStbbdev void parallel_scan( const Range& range, Body& body ) {
56549e08aacStbbdev     start_scan<Range, Body, auto_partitioner>::run(range,body,__TBB_DEFAULT_PARTITIONER());
56649e08aacStbbdev }
56749e08aacStbbdev 
56849e08aacStbbdev //! Parallel prefix with simple_partitioner
56949e08aacStbbdev /** @ingroup algorithms **/
57049e08aacStbbdev template<typename Range, typename Body>
__TBB_requires(tbb_range<Range> && parallel_scan_body<Body,Range>)5714a23d002Skboyarinov     __TBB_requires(tbb_range<Range> && parallel_scan_body<Body, Range>)
57249e08aacStbbdev void parallel_scan( const Range& range, Body& body, const simple_partitioner& partitioner ) {
57349e08aacStbbdev     start_scan<Range, Body, simple_partitioner>::run(range, body, partitioner);
57449e08aacStbbdev }
57549e08aacStbbdev 
57649e08aacStbbdev //! Parallel prefix with auto_partitioner
57749e08aacStbbdev /** @ingroup algorithms **/
57849e08aacStbbdev template<typename Range, typename Body>
__TBB_requires(tbb_range<Range> && parallel_scan_body<Body,Range>)5794a23d002Skboyarinov     __TBB_requires(tbb_range<Range> && parallel_scan_body<Body, Range>)
58049e08aacStbbdev void parallel_scan( const Range& range, Body& body, const auto_partitioner& partitioner ) {
58149e08aacStbbdev     start_scan<Range,Body,auto_partitioner>::run(range, body, partitioner);
58249e08aacStbbdev }
58349e08aacStbbdev 
58449e08aacStbbdev //! Parallel prefix with default partitioner
58549e08aacStbbdev /** @ingroup algorithms **/
58649e08aacStbbdev template<typename Range, typename Value, typename Scan, typename ReverseJoin>
__TBB_requires(tbb_range<Range> && parallel_scan_function<Scan,Range,Value> && parallel_scan_combine<ReverseJoin,Value>)5874a23d002Skboyarinov     __TBB_requires(tbb_range<Range> && parallel_scan_function<Scan, Range, Value> &&
588478de5b1Stbbdev                    parallel_scan_combine<ReverseJoin, Value>)
58949e08aacStbbdev Value parallel_scan( const Range& range, const Value& identity, const Scan& scan, const ReverseJoin& reverse_join ) {
59049e08aacStbbdev     lambda_scan_body<Range, Value, Scan, ReverseJoin> body(identity, scan, reverse_join);
59149e08aacStbbdev     parallel_scan(range, body, __TBB_DEFAULT_PARTITIONER());
59249e08aacStbbdev     return body.result();
59349e08aacStbbdev }
59449e08aacStbbdev 
59549e08aacStbbdev //! Parallel prefix with simple_partitioner
59649e08aacStbbdev /** @ingroup algorithms **/
59749e08aacStbbdev template<typename Range, typename Value, typename Scan, typename ReverseJoin>
__TBB_requires(tbb_range<Range> && parallel_scan_function<Scan,Range,Value> && parallel_scan_combine<ReverseJoin,Value>)5984a23d002Skboyarinov     __TBB_requires(tbb_range<Range> && parallel_scan_function<Scan, Range, Value> &&
599478de5b1Stbbdev                    parallel_scan_combine<ReverseJoin, Value>)
60049e08aacStbbdev Value parallel_scan( const Range& range, const Value& identity, const Scan& scan, const ReverseJoin& reverse_join,
60149e08aacStbbdev                      const simple_partitioner& partitioner ) {
60249e08aacStbbdev     lambda_scan_body<Range, Value, Scan, ReverseJoin> body(identity, scan, reverse_join);
60349e08aacStbbdev     parallel_scan(range, body, partitioner);
60449e08aacStbbdev     return body.result();
60549e08aacStbbdev }
60649e08aacStbbdev 
60749e08aacStbbdev //! Parallel prefix with auto_partitioner
60849e08aacStbbdev /** @ingroup algorithms **/
60949e08aacStbbdev template<typename Range, typename Value, typename Scan, typename ReverseJoin>
__TBB_requires(tbb_range<Range> && parallel_scan_function<Scan,Range,Value> && parallel_scan_combine<ReverseJoin,Value>)6104a23d002Skboyarinov     __TBB_requires(tbb_range<Range> && parallel_scan_function<Scan, Range, Value> &&
611478de5b1Stbbdev                    parallel_scan_combine<ReverseJoin, Value>)
61249e08aacStbbdev Value parallel_scan( const Range& range, const Value& identity, const Scan& scan, const ReverseJoin& reverse_join,
61349e08aacStbbdev                      const auto_partitioner& partitioner ) {
61449e08aacStbbdev     lambda_scan_body<Range, Value, Scan, ReverseJoin> body(identity, scan, reverse_join);
61549e08aacStbbdev     parallel_scan(range, body, partitioner);
61649e08aacStbbdev     return body.result();
61749e08aacStbbdev }
61849e08aacStbbdev 
61949e08aacStbbdev } // namespace d1
62049e08aacStbbdev } // namespace detail
62149e08aacStbbdev 
62249e08aacStbbdev inline namespace v1 {
62349e08aacStbbdev     using detail::d1::parallel_scan;
62449e08aacStbbdev     using detail::d1::pre_scan_tag;
62549e08aacStbbdev     using detail::d1::final_scan_tag;
62649e08aacStbbdev } // namespace v1
62749e08aacStbbdev 
62849e08aacStbbdev } // namespace tbb
62949e08aacStbbdev 
63049e08aacStbbdev #endif /* __TBB_parallel_scan_H */
631