1 /* 2 Copyright (c) 2005-2021 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 // Do not include task.h directly. Use scheduler_common.h instead 18 #include "scheduler_common.h" 19 #include "governor.h" 20 #include "arena.h" 21 #include "thread_data.h" 22 #include "task_dispatcher.h" 23 #include "waiters.h" 24 #include "itt_notify.h" 25 26 #include "oneapi/tbb/detail/_task.h" 27 #include "oneapi/tbb/partitioner.h" 28 #include "oneapi/tbb/task.h" 29 30 #include <cstring> 31 32 namespace tbb { 33 namespace detail { 34 namespace r1 { 35 36 //------------------------------------------------------------------------ 37 // resumable tasks 38 //------------------------------------------------------------------------ 39 #if __TBB_RESUMABLE_TASKS 40 41 void suspend(suspend_callback_type suspend_callback, void* user_callback) { 42 thread_data& td = *governor::get_thread_data(); 43 td.my_task_dispatcher->suspend(suspend_callback, user_callback); 44 // Do not access td after suspend. 45 } 46 47 void resume(suspend_point_type* sp) { 48 assert_pointers_valid(sp, sp->m_arena); 49 task_dispatcher& task_disp = sp->m_resume_task.m_target; 50 __TBB_ASSERT(task_disp.m_thread_data == nullptr, nullptr); 51 52 // TODO: remove this work-around 53 // Prolong the arena's lifetime while all coroutines are alive 54 // (otherwise the arena can be destroyed while some tasks are suspended). 55 arena& a = *sp->m_arena; 56 a.my_references += arena::ref_external; 57 58 if (task_disp.m_properties.critical_task_allowed) { 59 // The target is not in the process of executing critical task, so the resume task is not critical. 60 a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random)); 61 } else { 62 #if __TBB_PREVIEW_CRITICAL_TASKS 63 // The target is in the process of executing critical task, so the resume task is critical. 64 a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random)); 65 #endif 66 } 67 68 // Do not access target after that point. 69 a.advertise_new_work<arena::wakeup>(); 70 71 // Release our reference to my_arena. 72 a.on_thread_leaving<arena::ref_external>(); 73 } 74 75 suspend_point_type* current_suspend_point() { 76 thread_data& td = *governor::get_thread_data(); 77 return td.my_task_dispatcher->get_suspend_point(); 78 } 79 80 static task_dispatcher& create_coroutine(thread_data& td) { 81 // We may have some task dispatchers cached 82 task_dispatcher* task_disp = td.my_arena->my_co_cache.pop(); 83 if (!task_disp) { 84 void* ptr = cache_aligned_allocate(sizeof(task_dispatcher)); 85 task_disp = new(ptr) task_dispatcher(td.my_arena); 86 task_disp->init_suspend_point(td.my_arena, td.my_arena->my_market->worker_stack_size()); 87 } 88 // Prolong the arena's lifetime until all coroutines is alive 89 // (otherwise the arena can be destroyed while some tasks are suspended). 90 // TODO: consider behavior if there are more than 4K external references. 91 td.my_arena->my_references += arena::ref_external; 92 return *task_disp; 93 } 94 95 void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) { 96 __TBB_ASSERT(suspend_callback != nullptr, nullptr); 97 __TBB_ASSERT(user_callback != nullptr, nullptr); 98 __TBB_ASSERT(m_thread_data != nullptr, nullptr); 99 100 arena_slot* slot = m_thread_data->my_arena_slot; 101 __TBB_ASSERT(slot != nullptr, nullptr); 102 103 task_dispatcher& default_task_disp = slot->default_task_dispatcher(); 104 // TODO: simplify the next line, e.g. is_task_dispatcher_recalled( task_dispatcher& ) 105 bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire); 106 task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data); 107 108 thread_data::suspend_callback_wrapper callback = { suspend_callback, user_callback, get_suspend_point() }; 109 m_thread_data->set_post_resume_action(thread_data::post_resume_action::callback, &callback); 110 resume(target); 111 112 if (m_properties.outermost) { 113 recall_point(); 114 } 115 } 116 117 void task_dispatcher::resume(task_dispatcher& target) { 118 // Do not create non-trivial objects on the stack of this function. They might never be destroyed 119 { 120 thread_data* td = m_thread_data; 121 __TBB_ASSERT(&target != this, "We cannot resume to ourself"); 122 __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data"); 123 __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher"); 124 __TBB_ASSERT(td->my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set"); 125 __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument"); 126 127 // Change the task dispatcher 128 td->detach_task_dispatcher(); 129 td->attach_task_dispatcher(target); 130 } 131 __TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created"); 132 __TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created"); 133 // Swap to the target coroutine. 134 m_suspend_point->m_co_context.resume(target.m_suspend_point->m_co_context); 135 // Pay attention that m_thread_data can be changed after resume 136 { 137 thread_data* td = m_thread_data; 138 __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data"); 139 __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher"); 140 td->do_post_resume_action(); 141 142 // Remove the recall flag if the thread in its original task dispatcher 143 arena_slot* slot = td->my_arena_slot; 144 __TBB_ASSERT(slot != nullptr, nullptr); 145 if (this == slot->my_default_task_dispatcher) { 146 __TBB_ASSERT(m_suspend_point != nullptr, nullptr); 147 m_suspend_point->m_is_owner_recalled.store(false, std::memory_order_relaxed); 148 } 149 } 150 } 151 152 void thread_data::do_post_resume_action() { 153 __TBB_ASSERT(my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set"); 154 __TBB_ASSERT(my_post_resume_arg, "The post resume action must have an argument"); 155 156 switch (my_post_resume_action) { 157 case post_resume_action::register_waiter: 158 { 159 static_cast<extended_concurrent_monitor::resume_context*>(my_post_resume_arg)->notify(); 160 break; 161 } 162 case post_resume_action::resume: 163 { 164 r1::resume(static_cast<suspend_point_type*>(my_post_resume_arg)); 165 break; 166 } 167 case post_resume_action::callback: 168 { 169 suspend_callback_wrapper callback = *static_cast<suspend_callback_wrapper*>(my_post_resume_arg); 170 callback(); 171 break; 172 } 173 case post_resume_action::cleanup: 174 { 175 task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(my_post_resume_arg); 176 // Release coroutine's reference to my_arena. 177 my_arena->on_thread_leaving<arena::ref_external>(); 178 // Cache the coroutine for possible later re-usage 179 my_arena->my_co_cache.push(to_cleanup); 180 break; 181 } 182 case post_resume_action::notify: 183 { 184 std::atomic<bool>& owner_recall_flag = *static_cast<std::atomic<bool>*>(my_post_resume_arg); 185 owner_recall_flag.store(true, std::memory_order_release); 186 // Do not access recall_flag because it can be destroyed after the notification. 187 break; 188 } 189 default: 190 __TBB_ASSERT(false, "Unknown post resume action"); 191 } 192 193 my_post_resume_action = post_resume_action::none; 194 my_post_resume_arg = nullptr; 195 } 196 197 #else 198 199 void suspend(suspend_callback_type, void*) { 200 __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform"); 201 } 202 203 void resume(suspend_point_type*) { 204 __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform"); 205 } 206 207 suspend_point_type* current_suspend_point() { 208 __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform"); 209 return nullptr; 210 } 211 212 #endif /* __TBB_RESUMABLE_TASKS */ 213 214 void notify_waiters(std::uintptr_t wait_ctx_addr) { 215 auto is_related_wait_ctx = [&] (extended_context context) { 216 return wait_ctx_addr == context.my_uniq_addr; 217 }; 218 219 r1::governor::get_thread_data()->my_arena->my_market->get_wait_list().notify(is_related_wait_ctx); 220 } 221 222 } // namespace r1 223 } // namespace detail 224 } // namespace tbb 225 226