xref: /oneTBB/src/tbb/task.cpp (revision 49e08aac)
/*
    Copyright (c) 2005-2020 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

// Do not include task.h directly. Use scheduler_common.h instead.
#include "scheduler_common.h"
#include "governor.h"
#include "arena.h"
#include "thread_data.h"
#include "task_dispatcher.h"
#include "waiters.h"
#include "itt_notify.h"

#include "oneapi/tbb/detail/_task.h"
#include "oneapi/tbb/partitioner.h"
#include "oneapi/tbb/task.h"

#include <cstring>

namespace tbb {
namespace detail {
namespace r1 {

//------------------------------------------------------------------------
// resumable tasks
//------------------------------------------------------------------------
#if __TBB_RESUMABLE_TASKS

void suspend(suspend_callback_type suspend_callback, void* user_callback) {
    thread_data& td = *governor::get_thread_data();
    td.my_task_dispatcher->suspend(suspend_callback, user_callback);
    // Do not access td after suspend.
}
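
// Illustrative only: a rough sketch of how the public resumable-tasks API layered on
// top of this entry point (see oneapi/tbb/task.h) is typically used. The async_activity
// object below is a hypothetical user-side component, not part of oneTBB.
//
//     tbb::task::suspend([&](tbb::task::suspend_point sp) {
//         // The calling thread does not block: it returns to the scheduler and keeps
//         // executing other tasks while the suspend point is pending.
//         async_activity.submit(sp);
//     });
//     // Later, from any thread, once the asynchronous work has completed:
//     //     tbb::task::resume(sp);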

void resume(suspend_point_type* sp) {
    assert_pointers_valid(sp, sp->m_arena);
    task_dispatcher& task_disp = sp->m_resume_task.m_target;
    __TBB_ASSERT(task_disp.m_thread_data == nullptr, nullptr);

    // TODO: remove this work-around
    // Prolong the arena's lifetime while all coroutines are alive
    // (otherwise the arena can be destroyed while some tasks are suspended).
    arena& a = *sp->m_arena;
    a.my_references += arena::ref_external;

    if (task_disp.m_properties.critical_task_allowed) {
        // The target is not in the process of executing a critical task, so the resume task is not critical.
        a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
    } else {
#if __TBB_PREVIEW_CRITICAL_TASKS
        // The target is in the process of executing a critical task, so the resume task is critical.
        a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
#endif
    }

    // Do not access the target after this point.
    a.advertise_new_work<arena::wakeup>();

    // Release our reference to my_arena.
    a.on_thread_leaving<arena::ref_external>();
}

suspend_point_type* current_suspend_point() {
    thread_data& td = *governor::get_thread_data();
    return td.my_task_dispatcher->get_suspend_point();
}

static task_dispatcher& create_coroutine(thread_data& td) {
    // We may have some task dispatchers cached.
    task_dispatcher* task_disp = td.my_arena->my_co_cache.pop();
    if (!task_disp) {
        void* ptr = cache_aligned_allocate(sizeof(task_dispatcher));
        task_disp = new(ptr) task_dispatcher(td.my_arena);
        task_disp->init_suspend_point(td.my_arena, td.my_arena->my_market->worker_stack_size());
    }
    // Prolong the arena's lifetime while all coroutines are alive
    // (otherwise the arena can be destroyed while some tasks are suspended).
    // TODO: consider behavior if there are more than 4K external references.
    td.my_arena->my_references += arena::ref_external;
    return *task_disp;
}
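
// Note: the coroutine obtained here is handed back to my_arena->my_co_cache by
// thread_data::do_post_resume_action() (see post_resume_action::cleanup below),
// which also drops the extra arena reference taken above.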

void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
    __TBB_ASSERT(suspend_callback != nullptr, nullptr);
    __TBB_ASSERT(user_callback != nullptr, nullptr);
    __TBB_ASSERT(m_thread_data != nullptr, nullptr);

    arena_slot* slot = m_thread_data->my_arena_slot;
    __TBB_ASSERT(slot != nullptr, nullptr);

    task_dispatcher& default_task_disp = slot->default_task_dispatcher();
    // TODO: simplify the next line, e.g. is_task_dispatcher_recalled( task_dispatcher& )
    bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire);
    task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data);

    thread_data::suspend_callback_wrapper callback = { suspend_callback, user_callback, get_suspend_point() };
    m_thread_data->set_post_resume_action(thread_data::post_resume_action::callback, &callback);
    resume(target);

    if (m_properties.outermost) {
        recall_point();
    }
}
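
// In short, the suspend sequence above is:
//   1. Pick the target dispatcher: the slot's default dispatcher if its owner thread has
//      already been recalled, otherwise a cached or freshly created coroutine.
//   2. Register the user callback as a post-resume action; it runs on the target's stack
//      right after the context switch (see thread_data::do_post_resume_action()).
//   3. resume(target) switches coroutine contexts; once execution returns here, an
//      outermost dispatcher takes a recall point so that the owner thread can eventually
//      get back to its original task dispatcher.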

void task_dispatcher::resume(task_dispatcher& target) {
    // Do not create non-trivial objects on the stack of this function. They might never be destroyed.
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(&target != this, "We cannot resume ourselves");
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attached to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        __TBB_ASSERT(td->my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");

        // Change the task dispatcher.
        td->detach_task_dispatcher();
        td->attach_task_dispatcher(target);
    }
    __TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created");
    __TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created");
    // Swap to the target coroutine.
    m_suspend_point->m_co_context.resume(target.m_suspend_point->m_co_context);
    // Note that m_thread_data may have changed after the resume.
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attached to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        td->do_post_resume_action();

        // Clear the recall flag if the thread is in its original task dispatcher.
        arena_slot* slot = td->my_arena_slot;
        __TBB_ASSERT(slot != nullptr, nullptr);
        if (this == slot->my_default_task_dispatcher) {
            __TBB_ASSERT(m_suspend_point != nullptr, nullptr);
            m_suspend_point->m_is_owner_recalled.store(false, std::memory_order_relaxed);
        }
    }
}
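
// Note: m_co_context.resume() above returns only when some other dispatcher later
// resumes *this* coroutine; the post-resume action executed in the second block is
// therefore the one registered by that resuming dispatcher, not by the current call.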

void thread_data::do_post_resume_action() {
    __TBB_ASSERT(my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
    __TBB_ASSERT(my_post_resume_arg, "The post resume action must have an argument");

    switch (my_post_resume_action) {
    case post_resume_action::register_waiter:
    {
        auto& data = *static_cast<thread_data::register_waiter_data*>(my_post_resume_arg);
        using state = wait_node::node_state;
        state expected = state::not_ready;

        // There are three possible situations:
        // - the wait_context has finished => we call resume ourselves
        // - wait_context::continue_execution() returns true, but the CAS fails => we call resume ourselves
        // - wait_context::continue_execution() returns true, and the CAS succeeds => we are successfully committed to the wait list
        if (!data.wo->continue_execution() ||
#if defined(__INTEL_COMPILER) && __INTEL_COMPILER <= 1910
            !((std::atomic<unsigned>&)data.node.my_ready_flag).compare_exchange_strong((unsigned&)expected, (unsigned)state::ready))
#else
            !data.node.my_ready_flag.compare_exchange_strong(expected, state::ready))
#endif
        {
            data.node.my_suspend_point->m_arena->my_market->get_wait_list().cancel_wait(data.node);
            r1::resume(data.node.my_suspend_point);
        }

        break;
    }
    case post_resume_action::callback:
    {
        suspend_callback_wrapper callback = *static_cast<suspend_callback_wrapper*>(my_post_resume_arg);
        callback();
        break;
    }
    case post_resume_action::cleanup:
    {
        task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(my_post_resume_arg);
        // Release the coroutine's reference to my_arena.
        my_arena->on_thread_leaving<arena::ref_external>();
        // Cache the coroutine for possible later reuse.
        my_arena->my_co_cache.push(to_cleanup);
        break;
    }
    case post_resume_action::notify:
    {
        std::atomic<bool>& owner_recall_flag = *static_cast<std::atomic<bool>*>(my_post_resume_arg);
        owner_recall_flag.store(true, std::memory_order_release);
        // Do not access owner_recall_flag after this point; it can be destroyed after the notification.
        break;
    }
    default:
        __TBB_ASSERT(false, "Unknown post resume action");
    }

    my_post_resume_action = post_resume_action::none;
    my_post_resume_arg = nullptr;
}
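
// The 'callback' action is registered in task_dispatcher::suspend() above; the
// 'register_waiter', 'cleanup' and 'notify' actions are presumably set by the waiting
// and recall machinery elsewhere in the scheduler before the corresponding context switch.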

#else

void suspend(suspend_callback_type, void*) {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
}

void resume(suspend_point_type*) {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
}

suspend_point_type* current_suspend_point() {
    __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
    return nullptr;
}

#endif /* __TBB_RESUMABLE_TASKS */

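// Wakes up suspend points whose wait context matches wait_ctx_tag. Roughly: a thread
// that suspends while waiting registers its suspend point on the market's wait list
// (see post_resume_action::register_waiter above); when the matching wait context is
// released, this notification resumes those suspend points.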
void notify_waiters(std::uintptr_t wait_ctx_tag) {
    auto is_related_wait_ctx = [&] (extended_context context) {
        return wait_ctx_tag == context.uniq_ctx;
    };

    r1::governor::get_thread_data()->my_arena->my_market->get_wait_list().notify(is_related_wait_ctx);
}

} // namespace r1
} // namespace detail
} // namespace tbb