xref: /oneTBB/src/tbb/task.cpp (revision ce0d258e)
1 /*
2     Copyright (c) 2005-2022 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 // Do not include task.h directly. Use scheduler_common.h instead
18 #include "scheduler_common.h"
19 #include "governor.h"
20 #include "arena.h"
21 #include "thread_data.h"
22 #include "task_dispatcher.h"
23 #include "waiters.h"
24 #include "itt_notify.h"
25 
26 #include "oneapi/tbb/detail/_task.h"
27 #include "oneapi/tbb/partitioner.h"
28 #include "oneapi/tbb/task.h"
29 
30 #include <cstring>
31 
32 namespace tbb {
33 namespace detail {
34 namespace r1 {
35 
36 //------------------------------------------------------------------------
37 // resumable tasks
38 //------------------------------------------------------------------------
39 #if __TBB_RESUMABLE_TASKS
40 
41 void suspend(suspend_callback_type suspend_callback, void* user_callback) {
42     thread_data& td = *governor::get_thread_data();
43     td.my_task_dispatcher->suspend(suspend_callback, user_callback);
44     // Do not access td after suspend.
45 }
46 
47 void resume(suspend_point_type* sp) {
48     assert_pointers_valid(sp, sp->m_arena);
49     task_dispatcher& task_disp = sp->m_resume_task.m_target;
50 
51     if (sp->try_notify_resume()) {
52         // TODO: remove this work-around
53         // Prolong the arena's lifetime while all coroutines are alive
54         // (otherwise the arena can be destroyed while some tasks are suspended).
55         arena& a = *sp->m_arena;
56         a.my_references += arena::ref_external;
57 
58         if (task_disp.m_properties.critical_task_allowed) {
59             // The target is not in the process of executing critical task, so the resume task is not critical.
60             a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
61         } else {
62     #if __TBB_PREVIEW_CRITICAL_TASKS
63             // The target is in the process of executing critical task, so the resume task is critical.
64             a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
65     #endif
66         }
67         // Do not access target after that point.
68         a.advertise_new_work<arena::wakeup>();
69         // Release our reference to my_arena.
70         a.on_thread_leaving<arena::ref_external>();
71     }
72 
73 }
74 
75 suspend_point_type* current_suspend_point() {
76     thread_data& td = *governor::get_thread_data();
77     return td.my_task_dispatcher->get_suspend_point();
78 }
79 
80 static task_dispatcher& create_coroutine(thread_data& td) {
81     // We may have some task dispatchers cached
82     task_dispatcher* task_disp = td.my_arena->my_co_cache.pop();
83     if (!task_disp) {
84         void* ptr = cache_aligned_allocate(sizeof(task_dispatcher));
85         task_disp = new(ptr) task_dispatcher(td.my_arena);
86         task_disp->init_suspend_point(td.my_arena, td.my_arena->my_market->worker_stack_size());
87     }
88     // Prolong the arena's lifetime until all coroutines is alive
89     // (otherwise the arena can be destroyed while some tasks are suspended).
90     // TODO: consider behavior if there are more than 4K external references.
91     td.my_arena->my_references += arena::ref_external;
92     return *task_disp;
93 }
94 
95 void task_dispatcher::internal_suspend() {
96     __TBB_ASSERT(m_thread_data != nullptr, nullptr);
97 
98     arena_slot* slot = m_thread_data->my_arena_slot;
99     __TBB_ASSERT(slot != nullptr, nullptr);
100 
101     task_dispatcher& default_task_disp = slot->default_task_dispatcher();
102     // TODO: simplify the next line, e.g. is_task_dispatcher_recalled( task_dispatcher& )
103     bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire);
104     task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data);
105 
106     resume(target);
107 
108     if (m_properties.outermost) {
109         recall_point();
110     }
111 }
112 
113 void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
114     __TBB_ASSERT(suspend_callback != nullptr, nullptr);
115     __TBB_ASSERT(user_callback != nullptr, nullptr);
116     suspend_callback(user_callback, get_suspend_point());
117 
118     __TBB_ASSERT(m_thread_data != nullptr, nullptr);
119     __TBB_ASSERT(m_thread_data->my_post_resume_action == post_resume_action::none, nullptr);
120     __TBB_ASSERT(m_thread_data->my_post_resume_arg == nullptr, nullptr);
121     internal_suspend();
122 }
123 
// Switches the calling thread from this task dispatcher to `target`.
//
// The thread data is detached from this dispatcher and attached to `target`,
// and then the coroutine context is swapped via m_suspend_point->resume().
// Execution continues after that call only when this dispatcher is later
// switched back to. At that point m_thread_data is re-read because it may have
// changed in the meantime.
//
// Returns true when m_thread_data is non-null after the switch back. In that
// case the pending post-resume action has been executed, and the recall flag has
// been cleared if this is the slot's default dispatcher. Returns false otherwise.
bool task_dispatcher::resume(task_dispatcher& target) {
    // Do not create non-trivial objects on the stack of this function. They might never be destroyed
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(&target != this, "We cannot resume to ourself");
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");

        // Change the task dispatcher
        td->detach_task_dispatcher();
        td->attach_task_dispatcher(target);
    }
    __TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created");
    __TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created");
    // Swap to the target coroutine.

    m_suspend_point->resume(target.m_suspend_point);
    // Control returns here only when this dispatcher is resumed again.
    // Pay attention that m_thread_data can be changed after resume
    if (m_thread_data) {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        // Run whatever action the previous dispatcher left for the resumed one.
        do_post_resume_action();

        // Clear the recall flag if the thread is back in its original (default) task dispatcher
        arena_slot* slot = td->my_arena_slot;
        __TBB_ASSERT(slot != nullptr, nullptr);
        if (this == slot->my_default_task_dispatcher) {
            __TBB_ASSERT(m_suspend_point != nullptr, nullptr);
            m_suspend_point->m_is_owner_recalled.store(false, std::memory_order_relaxed);
        }
        return true;
    }
    return false;
}
159 
// Executes the post-resume action recorded in the thread data, then clears it.
// The action kind is td->my_post_resume_action, and its argument is the opaque
// pointer td->my_post_resume_arg, which is reinterpreted per action:
//   register_waiter - arg is a market_concurrent_monitor::resume_context; notify() it.
//   cleanup         - arg is a finished coroutine's task_dispatcher; drop its arena
//                     reference and return it to the arena's coroutine cache.
//   notify          - arg is a suspend_point_type; recall its owner and wake the
//                     market waiters registered under that suspend point's address.
//   none (default)  - nothing to do; the argument must be null.
void task_dispatcher::do_post_resume_action() {
    thread_data* td = m_thread_data;
    switch (td->my_post_resume_action) {
    case post_resume_action::register_waiter:
    {
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");
        static_cast<market_concurrent_monitor::resume_context*>(td->my_post_resume_arg)->notify();
        break;
    }
    case post_resume_action::cleanup:
    {
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");
        task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(td->my_post_resume_arg);
        // Release coroutine's reference to my_arena
        // (presumably the ref_external taken in create_coroutine()).
        td->my_arena->on_thread_leaving<arena::ref_external>();
        // Cache the coroutine for possible later re-usage
        td->my_arena->my_co_cache.push(to_cleanup);
        break;
    }
    case post_resume_action::notify:
    {
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");
        suspend_point_type* sp = static_cast<suspend_point_type*>(td->my_post_resume_arg);
        sp->recall_owner();
        // Do not access sp because it can be destroyed after recall

        // Only the address value of sp is compared below; sp is never dereferenced.
        auto is_our_suspend_point = [sp] (market_context ctx) {
            return std::uintptr_t(sp) == ctx.my_uniq_addr;
        };
        td->my_arena->my_market->get_wait_list().notify(is_our_suspend_point);
        break;
    }
    default:
        __TBB_ASSERT(td->my_post_resume_action == post_resume_action::none, "Unknown post resume action");
        __TBB_ASSERT(td->my_post_resume_arg == nullptr, "The post resume argument should not be set");
    }
    td->clear_post_resume_action();
}
198 
199 #else
200 
201 void suspend(suspend_callback_type, void*) {
202     __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
203 }
204 
205 void resume(suspend_point_type*) {
206     __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
207 }
208 
209 suspend_point_type* current_suspend_point() {
210     __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
211     return nullptr;
212 }
213 
214 #endif /* __TBB_RESUMABLE_TASKS */
215 
216 void notify_waiters(std::uintptr_t wait_ctx_addr) {
217     auto is_related_wait_ctx = [&] (market_context context) {
218         return wait_ctx_addr == context.my_uniq_addr;
219     };
220 
221     r1::governor::get_thread_data()->my_arena->my_market->get_wait_list().notify(is_related_wait_ctx);
222 }
223 
224 } // namespace r1
225 } // namespace detail
226 } // namespace tbb
227 
228