xref: /oneTBB/src/tbb/task.cpp (revision 51c0b2f7)
1*51c0b2f7Stbbdev /*
2*51c0b2f7Stbbdev     Copyright (c) 2005-2020 Intel Corporation
3*51c0b2f7Stbbdev 
4*51c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
5*51c0b2f7Stbbdev     you may not use this file except in compliance with the License.
6*51c0b2f7Stbbdev     You may obtain a copy of the License at
7*51c0b2f7Stbbdev 
8*51c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
9*51c0b2f7Stbbdev 
10*51c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
11*51c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
12*51c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*51c0b2f7Stbbdev     See the License for the specific language governing permissions and
14*51c0b2f7Stbbdev     limitations under the License.
15*51c0b2f7Stbbdev */
16*51c0b2f7Stbbdev 
17*51c0b2f7Stbbdev // Do not include task.h directly. Use scheduler_common.h instead
18*51c0b2f7Stbbdev #include "scheduler_common.h"
19*51c0b2f7Stbbdev #include "governor.h"
20*51c0b2f7Stbbdev #include "arena.h"
21*51c0b2f7Stbbdev #include "thread_data.h"
22*51c0b2f7Stbbdev #include "task_dispatcher.h"
23*51c0b2f7Stbbdev #include "waiters.h"
24*51c0b2f7Stbbdev #include "itt_notify.h"
25*51c0b2f7Stbbdev 
26*51c0b2f7Stbbdev #include "tbb/detail/_task.h"
27*51c0b2f7Stbbdev #include "tbb/partitioner.h"
28*51c0b2f7Stbbdev #include "tbb/task.h"
29*51c0b2f7Stbbdev 
30*51c0b2f7Stbbdev #include <cstring>
31*51c0b2f7Stbbdev 
32*51c0b2f7Stbbdev namespace tbb {
33*51c0b2f7Stbbdev namespace detail {
34*51c0b2f7Stbbdev 
35*51c0b2f7Stbbdev namespace d1 {
36*51c0b2f7Stbbdev 
37*51c0b2f7Stbbdev bool wait_context::is_locked() {
38*51c0b2f7Stbbdev     return m_ref_count.load(std::memory_order_relaxed) & lock_flag;
39*51c0b2f7Stbbdev }
40*51c0b2f7Stbbdev 
41*51c0b2f7Stbbdev void wait_context::lock() {
42*51c0b2f7Stbbdev     atomic_backoff backoff;
43*51c0b2f7Stbbdev 
44*51c0b2f7Stbbdev     auto try_lock = [&] { return !(m_ref_count.fetch_or(lock_flag) & lock_flag); };
45*51c0b2f7Stbbdev 
46*51c0b2f7Stbbdev     // While is_locked return true try_lock is not invoked
47*51c0b2f7Stbbdev     while (is_locked() || !try_lock()) {
48*51c0b2f7Stbbdev         backoff.pause();
49*51c0b2f7Stbbdev     }
50*51c0b2f7Stbbdev }
51*51c0b2f7Stbbdev 
52*51c0b2f7Stbbdev void wait_context::unlock() {
53*51c0b2f7Stbbdev     __TBB_ASSERT(is_locked(), NULL);
54*51c0b2f7Stbbdev     m_ref_count.fetch_and(~lock_flag);
55*51c0b2f7Stbbdev }
56*51c0b2f7Stbbdev 
// Atomically sets waiter_flag in m_ref_count to announce that the wait list
// is populated and must be notified when the tracked work completes.
// Invariant: the flag must never be added once the work is already done,
// otherwise nobody would ever clear it again.
// Returns true iff waiter_flag is present in m_ref_count on exit
// (i.e. the wait list was successfully published before the work finished).
bool wait_context::publish_wait_list() {
    auto expected = m_ref_count.load(std::memory_order_relaxed);
    // Callers either hold the wait-list lock, or operate on a legacy
    // (version 0) wait_context that has no wait list to protect.
    __TBB_ASSERT(is_locked() || m_version_and_traits == 0, NULL);

    // CAS loop: retry while the flag is absent AND the work is still running.
    // continue_execution() is re-checked on every iteration so the flag is
    // never installed after the reference count has already dropped to zero.
    while (!(expected & waiter_flag) && continue_execution()) {
        if (m_ref_count.compare_exchange_strong(expected, expected | waiter_flag)) {
            __TBB_ASSERT(!(expected & waiter_flag), NULL);
            expected |= waiter_flag;
            break;
        }
    }

    // Report whether waiter_flag made it into ref_count.
    return expected & waiter_flag;
}
75*51c0b2f7Stbbdev 
76*51c0b2f7Stbbdev void wait_context::unregister_waiter(r1::wait_node& node) {
77*51c0b2f7Stbbdev     lock_guard lock(*this);
78*51c0b2f7Stbbdev 
79*51c0b2f7Stbbdev     if (m_wait_head != nullptr) {
80*51c0b2f7Stbbdev         if (m_wait_head == &node) {
81*51c0b2f7Stbbdev             m_wait_head = node.my_next;
82*51c0b2f7Stbbdev         }
83*51c0b2f7Stbbdev         node.unlink();
84*51c0b2f7Stbbdev     }
85*51c0b2f7Stbbdev }
86*51c0b2f7Stbbdev 
87*51c0b2f7Stbbdev void wait_context::notify_waiters() {
88*51c0b2f7Stbbdev     lock_guard lock(*this);
89*51c0b2f7Stbbdev 
90*51c0b2f7Stbbdev     if (m_wait_head != nullptr) {
91*51c0b2f7Stbbdev         m_wait_head->notify_all(*this);
92*51c0b2f7Stbbdev         m_wait_head = nullptr;
93*51c0b2f7Stbbdev     }
94*51c0b2f7Stbbdev 
95*51c0b2f7Stbbdev     m_ref_count.store(m_ref_count.load(std::memory_order_relaxed) & ~waiter_flag, std::memory_order_relaxed);
96*51c0b2f7Stbbdev }
97*51c0b2f7Stbbdev 
98*51c0b2f7Stbbdev } // namespace d1
99*51c0b2f7Stbbdev 
100*51c0b2f7Stbbdev namespace r1 {
101*51c0b2f7Stbbdev 
102*51c0b2f7Stbbdev //------------------------------------------------------------------------
103*51c0b2f7Stbbdev // resumable tasks
104*51c0b2f7Stbbdev //------------------------------------------------------------------------
105*51c0b2f7Stbbdev #if __TBB_RESUMABLE_TASKS
106*51c0b2f7Stbbdev 
107*51c0b2f7Stbbdev void suspend(suspend_callback_type suspend_callback, void* user_callback) {
108*51c0b2f7Stbbdev     thread_data& td = *governor::get_thread_data();
109*51c0b2f7Stbbdev     td.my_task_dispatcher->suspend(suspend_callback, user_callback);
110*51c0b2f7Stbbdev     // Do not access td after suspend.
111*51c0b2f7Stbbdev }
112*51c0b2f7Stbbdev 
// Schedules the resume task of the given suspend point for execution.
// May be called from any thread; the suspended coroutine is actually resumed
// by whichever worker pops the resume task from the arena's task stream.
void resume(suspend_point_type* sp) {
    assert_pointers_valid(sp, sp->m_arena);
    task_dispatcher& task_disp = sp->m_resume_task.m_target;
    // The target dispatcher must be detached (suspended), i.e. not currently
    // bound to any thread.
    __TBB_ASSERT(task_disp.m_thread_data == nullptr, nullptr);

    // TODO: remove this work-around
    // Prolong the arena's lifetime while all coroutines are alive
    // (otherwise the arena can be destroyed while some tasks are suspended).
    arena& a = *sp->m_arena;
    a.my_references += arena::ref_external;

    if (task_disp.m_properties.critical_task_allowed) {
        // The target is not in the process of executing critical task, so the resume task is not critical.
        a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
    } else {
#if __TBB_PREVIEW_CRITICAL_TASKS
        // The target is in the process of executing critical task, so the resume task is critical.
        a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
#endif
    }

    // Do not access the target after this point: another thread may already
    // have popped and resumed it.
    a.advertise_new_work<arena::wakeup>();

    // Release our temporary external reference to the arena.
    a.on_thread_leaving<arena::ref_external>();
}
140*51c0b2f7Stbbdev 
141*51c0b2f7Stbbdev suspend_point_type* current_suspend_point() {
142*51c0b2f7Stbbdev     thread_data& td = *governor::get_thread_data();
143*51c0b2f7Stbbdev     return td.my_task_dispatcher->get_suspend_point();
144*51c0b2f7Stbbdev }
145*51c0b2f7Stbbdev 
146*51c0b2f7Stbbdev static task_dispatcher& create_coroutine(thread_data& td) {
147*51c0b2f7Stbbdev     // We may have some task dispatchers cached
148*51c0b2f7Stbbdev     task_dispatcher* task_disp = td.my_arena->my_co_cache.pop();
149*51c0b2f7Stbbdev     if (!task_disp) {
150*51c0b2f7Stbbdev         void* ptr = cache_aligned_allocate(sizeof(task_dispatcher));
151*51c0b2f7Stbbdev         task_disp = new(ptr) task_dispatcher(td.my_arena);
152*51c0b2f7Stbbdev         task_disp->init_suspend_point(td.my_arena, td.my_arena->my_market->worker_stack_size());
153*51c0b2f7Stbbdev     }
154*51c0b2f7Stbbdev     // Prolong the arena's lifetime until all coroutines is alive
155*51c0b2f7Stbbdev     // (otherwise the arena can be destroyed while some tasks are suspended).
156*51c0b2f7Stbbdev     // TODO: consider behavior if there are more than 4K external references.
157*51c0b2f7Stbbdev     td.my_arena->my_references += arena::ref_external;
158*51c0b2f7Stbbdev     return *task_disp;
159*51c0b2f7Stbbdev }
160*51c0b2f7Stbbdev 
// Suspends the current execution context and transfers control to another
// task dispatcher: the thread's default dispatcher if its owner has been
// recalled, or a (cached or new) coroutine-backed dispatcher otherwise.
// The user callback runs on the target context via the post-resume action.
void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
    __TBB_ASSERT(suspend_callback != nullptr, nullptr);
    __TBB_ASSERT(user_callback != nullptr, nullptr);
    __TBB_ASSERT(m_thread_data != nullptr, nullptr);

    arena_slot* slot = m_thread_data->my_arena_slot;
    __TBB_ASSERT(slot != nullptr, nullptr);

    task_dispatcher& default_task_disp = slot->default_task_dispatcher();
    // TODO: simplify the next line, e.g. is_task_dispatcher_recalled( task_dispatcher& )
    // If the owner thread was recalled, switch back to its default dispatcher;
    // otherwise obtain a coroutine-backed dispatcher to continue stealing work.
    bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire);
    task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data);

    // Arm the post-resume action BEFORE switching: the callback is executed
    // on the target's stack right after the context switch completes.
    thread_data::suspend_callback_wrapper callback = { suspend_callback, user_callback, get_suspend_point() };
    m_thread_data->set_post_resume_action(thread_data::post_resume_action::callback, &callback);
    resume(target);

    // Execution continues here only once this dispatcher is resumed again.
    if (m_properties.outermost) {
        recall_point();
    }
}
182*51c0b2f7Stbbdev 
// Switches the calling thread from this dispatcher's coroutine to the target
// dispatcher's coroutine, then performs the pending post-resume action when
// control eventually returns to this dispatcher.
void task_dispatcher::resume(task_dispatcher& target) {
    // Do not create non-trivial objects on the stack of this function. They might never be destroyed
    // (the coroutine below may be abandoned rather than unwound).
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(&target != this, "We cannot resume to ourself");
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        __TBB_ASSERT(td->my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");

        // Change the task dispatcher: rebind the thread data to the target
        // before the context switch so the target wakes up fully attached.
        td->detach_task_dispatcher();
        td->attach_task_dispatcher(target);
    }
    __TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created");
    __TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created");
    // Swap to the target coroutine. Control only comes back here when some
    // thread later resumes THIS dispatcher.
    m_suspend_point->m_co_context.resume(target.m_suspend_point->m_co_context);
    // Pay attention that m_thread_data can be changed after resume:
    // a different thread may now be running this dispatcher.
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        td->do_post_resume_action();

        // Remove the recall flag if the thread in its original task dispatcher
        arena_slot* slot = td->my_arena_slot;
        __TBB_ASSERT(slot != nullptr, nullptr);
        if (this == slot->my_default_task_dispatcher) {
            __TBB_ASSERT(m_suspend_point != nullptr, nullptr);
            m_suspend_point->m_is_owner_recalled.store(false, std::memory_order_relaxed);
        }
    }
}
217*51c0b2f7Stbbdev 
// Executes the action armed via set_post_resume_action() before the last
// context switch, then resets the action slot. Runs on the newly resumed
// context, i.e. on the "other side" of the coroutine swap.
void thread_data::do_post_resume_action() {
    __TBB_ASSERT(my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
    __TBB_ASSERT(my_post_resume_arg, "The post resume action must have an argument");

    switch (my_post_resume_action) {
    case post_resume_action::register_waiter:
    {
        auto& data = *static_cast<thread_data::register_waiter_data*>(my_post_resume_arg);

        // Support of backward compatibility: a version-0 wait_context has no
        // real wait list, so the suspend point pointer itself is smuggled
        // through m_wait_head (reinterpreted back on the notification side).
        if (data.wo.m_version_and_traits == 0) {
            data.wo.m_wait_head = reinterpret_cast<wait_node*>(data.node.my_suspend_point);
            // If the work finished before the flag was published, resume immediately.
            if (!data.wo.publish_wait_list()) {
                r1::resume(data.node.my_suspend_point);
            }
            break;
        }

        // Modern path: register on the wait list only while work is still running;
        // otherwise resume the suspended context right away.
        auto wait_condition = [&data] { return data.wo.continue_execution(); };
        if (!data.wo.try_register_waiter(data.node, wait_condition)) {
            r1::resume(data.node.my_suspend_point);
        }

        break;
    }
    case post_resume_action::callback:
    {
        // Copy the wrapper locally: the argument lives on the suspended stack.
        suspend_callback_wrapper callback = *static_cast<suspend_callback_wrapper*>(my_post_resume_arg);
        callback();
        break;
    }
    case post_resume_action::cleanup:
    {
        // The previous coroutine has fully exited; recycle its dispatcher.
        task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(my_post_resume_arg);
        // Release coroutine's reference to my_arena.
        my_arena->on_thread_leaving<arena::ref_external>();
        // Cache the coroutine for possible later re-usage
        my_arena->my_co_cache.push(to_cleanup);
        break;
    }
    case post_resume_action::notify:
    {
        // Signal the recalled owner thread that it may proceed.
        std::atomic<bool>& owner_recall_flag = *static_cast<std::atomic<bool>*>(my_post_resume_arg);
        owner_recall_flag.store(true, std::memory_order_release);
        // Do not access recall_flag because it can be destroyed after the notification.
        break;
    }
    default:
        __TBB_ASSERT(false, "Unknown post resume action");
    }

    // Reset the slot so a stale action is never replayed on the next resume.
    my_post_resume_action = post_resume_action::none;
    my_post_resume_arg = nullptr;
}
272*51c0b2f7Stbbdev 
273*51c0b2f7Stbbdev #else
274*51c0b2f7Stbbdev 
// Stub for platforms without coroutine support: calling suspend is a hard error.
void suspend(suspend_callback_type, void*) {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
}
278*51c0b2f7Stbbdev 
// Stub for platforms without coroutine support: calling resume is a hard error.
void resume(suspend_point_type*) {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
}
282*51c0b2f7Stbbdev 
// Stub for platforms without coroutine support: always fails; the return
// statement only satisfies the signature.
suspend_point_type* current_suspend_point() {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
    return nullptr;
}
287*51c0b2f7Stbbdev 
288*51c0b2f7Stbbdev #endif /* __TBB_RESUMABLE_TASKS */
289*51c0b2f7Stbbdev 
290*51c0b2f7Stbbdev void notify_waiters(d1::wait_context& wc) {
291*51c0b2f7Stbbdev     __TBB_ASSERT(wc.m_version_and_traits > 0, NULL);
292*51c0b2f7Stbbdev 
293*51c0b2f7Stbbdev     wc.notify_waiters();
294*51c0b2f7Stbbdev }
295*51c0b2f7Stbbdev 
296*51c0b2f7Stbbdev } // namespace r1
297*51c0b2f7Stbbdev } // namespace detail
298*51c0b2f7Stbbdev } // namespace tbb
299*51c0b2f7Stbbdev 
300