xref: /oneTBB/src/tbb/task.cpp (revision 51c0b2f7)
1 /*
2     Copyright (c) 2005-2020 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 // Do not include task.h directly. Use scheduler_common.h instead
18 #include "scheduler_common.h"
19 #include "governor.h"
20 #include "arena.h"
21 #include "thread_data.h"
22 #include "task_dispatcher.h"
23 #include "waiters.h"
24 #include "itt_notify.h"
25 
26 #include "tbb/detail/_task.h"
27 #include "tbb/partitioner.h"
28 #include "tbb/task.h"
29 
30 #include <cstring>
31 
32 namespace tbb {
33 namespace detail {
34 
35 namespace d1 {
36 
37 bool wait_context::is_locked() {
38     return m_ref_count.load(std::memory_order_relaxed) & lock_flag;
39 }
40 
41 void wait_context::lock() {
42     atomic_backoff backoff;
43 
44     auto try_lock = [&] { return !(m_ref_count.fetch_or(lock_flag) & lock_flag); };
45 
46     // While is_locked return true try_lock is not invoked
47     while (is_locked() || !try_lock()) {
48         backoff.pause();
49     }
50 }
51 
52 void wait_context::unlock() {
53     __TBB_ASSERT(is_locked(), NULL);
54     m_ref_count.fetch_and(~lock_flag);
55 }
56 
57 bool wait_context::publish_wait_list() {
58     // Try to add waiter_flag to the ref_counter
59     // Important : This function should never add waiter_flag if work is done otherwise waiter_flag will be never removed
60 
61     auto expected = m_ref_count.load(std::memory_order_relaxed);
62     __TBB_ASSERT(is_locked() || m_version_and_traits == 0, NULL);
63 
64     while (!(expected & waiter_flag) && continue_execution()) {
65         if (m_ref_count.compare_exchange_strong(expected, expected | waiter_flag)) {
66             __TBB_ASSERT(!(expected & waiter_flag), NULL);
67             expected |= waiter_flag;
68             break;
69         }
70     }
71 
72     // There is waiter_flag in ref_count
73     return expected & waiter_flag;
74 }
75 
76 void wait_context::unregister_waiter(r1::wait_node& node) {
77     lock_guard lock(*this);
78 
79     if (m_wait_head != nullptr) {
80         if (m_wait_head == &node) {
81             m_wait_head = node.my_next;
82         }
83         node.unlink();
84     }
85 }
86 
void wait_context::notify_waiters() {
    // Wake every waiter registered on this context and clear the wait list.
    // Serialized against registration/unregistration via the context lock.
    lock_guard lock(*this);

    if (m_wait_head != nullptr) {
        m_wait_head->notify_all(*this);
        m_wait_head = nullptr;
    }

    // Clear waiter_flag with a plain load/store pair rather than an atomic RMW.
    // NOTE(review): this presumably relies on no concurrent update of
    // m_ref_count racing with this read-modify-write window while the lock bit
    // is held — TODO confirm that reference-count arithmetic cannot be lost here.
    m_ref_count.store(m_ref_count.load(std::memory_order_relaxed) & ~waiter_flag, std::memory_order_relaxed);
}
97 
98 } // namespace d1
99 
100 namespace r1 {
101 
102 //------------------------------------------------------------------------
103 // resumable tasks
104 //------------------------------------------------------------------------
105 #if __TBB_RESUMABLE_TASKS
106 
107 void suspend(suspend_callback_type suspend_callback, void* user_callback) {
108     thread_data& td = *governor::get_thread_data();
109     td.my_task_dispatcher->suspend(suspend_callback, user_callback);
110     // Do not access td after suspend.
111 }
112 
void resume(suspend_point_type* sp) {
    // Publish the resume task associated with `sp` into its arena so that some
    // thread picks it up and switches to the suspended coroutine.
    assert_pointers_valid(sp, sp->m_arena);
    task_dispatcher& task_disp = sp->m_resume_task.m_target;
    __TBB_ASSERT(task_disp.m_thread_data == nullptr, nullptr);

    // TODO: remove this work-around
    // Prolong the arena's lifetime while all coroutines are alive
    // (otherwise the arena can be destroyed while some tasks are suspended).
    arena& a = *sp->m_arena;
    a.my_references += arena::ref_external;

    if (task_disp.m_properties.critical_task_allowed) {
        // The target is not in the process of executing critical task, so the resume task is not critical.
        a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
    } else {
#if __TBB_PREVIEW_CRITICAL_TASKS
        // The target is in the process of executing critical task, so the resume task is critical.
        a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
#endif
    }

    // Do not access target after that point: once pushed, another thread may
    // already be running (and destroying) it.
    a.advertise_new_work<arena::wakeup>();

    // Release our reference to my_arena.
    a.on_thread_leaving<arena::ref_external>();
}
140 
141 suspend_point_type* current_suspend_point() {
142     thread_data& td = *governor::get_thread_data();
143     return td.my_task_dispatcher->get_suspend_point();
144 }
145 
146 static task_dispatcher& create_coroutine(thread_data& td) {
147     // We may have some task dispatchers cached
148     task_dispatcher* task_disp = td.my_arena->my_co_cache.pop();
149     if (!task_disp) {
150         void* ptr = cache_aligned_allocate(sizeof(task_dispatcher));
151         task_disp = new(ptr) task_dispatcher(td.my_arena);
152         task_disp->init_suspend_point(td.my_arena, td.my_arena->my_market->worker_stack_size());
153     }
154     // Prolong the arena's lifetime until all coroutines is alive
155     // (otherwise the arena can be destroyed while some tasks are suspended).
156     // TODO: consider behavior if there are more than 4K external references.
157     td.my_arena->my_references += arena::ref_external;
158     return *task_disp;
159 }
160 
void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
    // Switch this thread away from the current coroutine; the user callback
    // runs with the captured suspend point as a post-resume action.
    __TBB_ASSERT(suspend_callback != nullptr, nullptr);
    __TBB_ASSERT(user_callback != nullptr, nullptr);
    __TBB_ASSERT(m_thread_data != nullptr, nullptr);

    arena_slot* slot = m_thread_data->my_arena_slot;
    __TBB_ASSERT(slot != nullptr, nullptr);

    task_dispatcher& default_task_disp = slot->default_task_dispatcher();
    // TODO: simplify the next line, e.g. is_task_dispatcher_recalled( task_dispatcher& )
    // If the slot's default dispatcher was recalled, switch back to it;
    // otherwise obtain a (possibly cached) coroutine dispatcher to run on.
    bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire);
    task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data);

    // The wrapper lives on this (suspended) stack; the resuming side copies it
    // before invoking (see do_post_resume_action).
    thread_data::suspend_callback_wrapper callback = { suspend_callback, user_callback, get_suspend_point() };
    m_thread_data->set_post_resume_action(thread_data::post_resume_action::callback, &callback);
    resume(target);

    if (m_properties.outermost) {
        recall_point();
    }
}
182 
void task_dispatcher::resume(task_dispatcher& target) {
    // Transfer the calling thread from this dispatcher's coroutine to the
    // target dispatcher's coroutine, then run the pending post-resume action.
    // Do not create non-trivial objects on the stack of this function. They might never be destroyed
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(&target != this, "We cannot resume to ourself");
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        __TBB_ASSERT(td->my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");

        // Change the task dispatcher: the thread now logically runs `target`.
        td->detach_task_dispatcher();
        td->attach_task_dispatcher(target);
    }
    __TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created");
    __TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created");
    // Swap to the target coroutine. Execution continues below only when some
    // thread resumes this dispatcher again.
    m_suspend_point->m_co_context.resume(target.m_suspend_point->m_co_context);
    // Pay attention that m_thread_data can be changed after resume
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        // Execute whatever action the suspending side requested (callback,
        // cleanup, notify, ...).
        td->do_post_resume_action();

        // Remove the recall flag if the thread in its original task dispatcher
        arena_slot* slot = td->my_arena_slot;
        __TBB_ASSERT(slot != nullptr, nullptr);
        if (this == slot->my_default_task_dispatcher) {
            __TBB_ASSERT(m_suspend_point != nullptr, nullptr);
            m_suspend_point->m_is_owner_recalled.store(false, std::memory_order_relaxed);
        }
    }
}
217 
void thread_data::do_post_resume_action() {
    // Execute the deferred action registered (via set_post_resume_action)
    // before this thread switched coroutines, then reset the action slot.
    __TBB_ASSERT(my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
    __TBB_ASSERT(my_post_resume_arg, "The post resume action must have an argument");

    switch (my_post_resume_action) {
    case post_resume_action::register_waiter:
    {
        auto& data = *static_cast<thread_data::register_waiter_data*>(my_post_resume_arg);

        // Support of backward compatibility
        if (data.wo.m_version_and_traits == 0) {
            // Legacy (version 0) contexts carry no real wait list: the suspend
            // point itself is stashed in m_wait_head via a cast.
            // NOTE(review): relies on a layout/usage contract between wait_node
            // and suspend_point_type — confirm at their declaration sites.
            data.wo.m_wait_head = reinterpret_cast<wait_node*>(data.node.my_suspend_point);
            if (!data.wo.publish_wait_list()) {
                // The awaited work already finished; resume immediately.
                r1::resume(data.node.my_suspend_point);
            }
            break;
        }

        auto wait_condition = [&data] { return data.wo.continue_execution(); };
        if (!data.wo.try_register_waiter(data.node, wait_condition)) {
            // Registration failed because the wait is already over; resume now.
            r1::resume(data.node.my_suspend_point);
        }

        break;
    }
    case post_resume_action::callback:
    {
        // Copy the wrapper locally: the original lives on the suspended stack.
        suspend_callback_wrapper callback = *static_cast<suspend_callback_wrapper*>(my_post_resume_arg);
        callback();
        break;
    }
    case post_resume_action::cleanup:
    {
        task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(my_post_resume_arg);
        // Release coroutine's reference to my_arena.
        my_arena->on_thread_leaving<arena::ref_external>();
        // Cache the coroutine for possible later re-usage
        my_arena->my_co_cache.push(to_cleanup);
        break;
    }
    case post_resume_action::notify:
    {
        std::atomic<bool>& owner_recall_flag = *static_cast<std::atomic<bool>*>(my_post_resume_arg);
        owner_recall_flag.store(true, std::memory_order_release);
        // Do not access recall_flag because it can be destroyed after the notification.
        break;
    }
    default:
        __TBB_ASSERT(false, "Unknown post resume action");
    }

    // Reset so a stale action is never re-executed on the next switch.
    my_post_resume_action = post_resume_action::none;
    my_post_resume_arg = nullptr;
}
272 
273 #else
274 
// Stub for platforms without coroutine support: suspending is impossible.
void suspend(suspend_callback_type, void*) {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
}
278 
// Stub for platforms without coroutine support: there is nothing to resume.
void resume(suspend_point_type*) {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
}
282 
// Stub for platforms without coroutine support: no suspend point ever exists.
suspend_point_type* current_suspend_point() {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
    return nullptr;
}
287 
288 #endif /* __TBB_RESUMABLE_TASKS */
289 
290 void notify_waiters(d1::wait_context& wc) {
291     __TBB_ASSERT(wc.m_version_and_traits > 0, NULL);
292 
293     wc.notify_waiters();
294 }
295 
296 } // namespace r1
297 } // namespace detail
298 } // namespace tbb
299 
300