1 /* 2 Copyright (c) 2023 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #ifndef __TBB_task_emulation_layer_H 18 #define __TBB_task_emulation_layer_H 19 20 #include "tbb/task_group.h" 21 #include "tbb/task_arena.h" 22 23 #include <atomic> 24 25 namespace task_emulation { 26 27 struct task_group_pool { 28 task_group_pool() : pool_size(std::thread::hardware_concurrency()), task_submitters(new tbb::task_group[pool_size]) {} 29 30 ~task_group_pool() { 31 for (std::size_t i = 0; i < pool_size; ++i) { 32 task_submitters[i].wait(); 33 } 34 35 delete [] task_submitters; 36 } 37 38 tbb::task_group& operator[] (std::size_t idx) { return task_submitters[idx]; } 39 40 const std::size_t pool_size; 41 tbb::task_group* task_submitters; 42 }; 43 44 static task_group_pool tg_pool; 45 46 class base_task { 47 public: 48 base_task() = default; 49 50 base_task(const base_task& t) : m_type(t.m_type), m_parent(t.m_parent), m_child_counter(t.m_child_counter.load()) 51 {} 52 53 virtual ~base_task() = default; 54 55 void operator() () const { 56 task_type type_snapshot = m_type; 57 58 base_task* bypass = const_cast<base_task*>(this)->execute(); 59 60 if (m_parent && m_type != task_type::recycled) { 61 if (m_parent->remove_child_reference() == 0) { 62 m_parent->operator()(); 63 } 64 } 65 66 if (m_type == task_type::allocated) { 67 delete this; 68 } 69 70 if (bypass != nullptr) { 71 m_type = type_snapshot; 72 73 // Bypass is not supported by task_emulation and next_task executed directly. 74 // However, the old-TBB bypass behavior can be achieved with 75 // `return task_group::defer()` (check Migration Guide). 76 // Consider submit another task if recursion call is not acceptable 77 // i.e. instead of Direct Body call 78 // submit task_emulation::run_task(); 79 bypass->operator()(); 80 } 81 } 82 83 virtual base_task* execute() = 0; 84 85 template <typename C, typename... Args> 86 C* allocate_continuation(std::uint64_t ref, Args&&... args) { 87 C* continuation = new C{std::forward<Args>(args)...}; 88 continuation->m_type = task_type::allocated; 89 continuation->reset_parent(reset_parent()); 90 continuation->m_child_counter = ref; 91 return continuation; 92 } 93 94 template <typename F, typename... Args> 95 F create_child(Args&&... args) { 96 return create_child_impl<F>(std::forward<Args>(args)...); 97 } 98 99 template <typename F, typename... Args> 100 F create_child_and_increment(Args&&... args) { 101 add_child_reference(); 102 return create_child_impl<F>(std::forward<Args>(args)...); 103 } 104 105 template <typename F, typename... Args> 106 F* allocate_child(Args&&... args) { 107 return allocate_child_impl<F>(std::forward<Args>(args)...); 108 } 109 110 template <typename F, typename... Args> 111 F* allocate_child_and_increment(Args&&... args) { 112 add_child_reference(); 113 return allocate_child_impl<F>(std::forward<Args>(args)...); 114 } 115 116 template <typename C> 117 void recycle_as_child_of(C& c) { 118 m_type = task_type::recycled; 119 reset_parent(&c); 120 } 121 122 void recycle_as_continuation() { 123 m_type = task_type::recycled; 124 } 125 126 void add_child_reference() { 127 ++m_child_counter; 128 } 129 130 std::uint64_t remove_child_reference() { 131 return --m_child_counter; 132 } 133 134 protected: 135 enum class task_type { 136 stack_based, 137 allocated, 138 recycled 139 }; 140 141 mutable task_type m_type; 142 143 private: 144 template <typename F, typename... Args> 145 friend F create_root_task(tbb::task_group& tg, Args&&... args); 146 147 template <typename F, typename... Args> 148 friend F* allocate_root_task(tbb::task_group& tg, Args&&... args); 149 150 template <typename F, typename... Args> 151 F create_child_impl(Args&&... args) { 152 F obj{std::forward<Args>(args)...}; 153 obj.m_type = task_type::stack_based; 154 obj.reset_parent(this); 155 return obj; 156 } 157 158 template <typename F, typename... Args> 159 F* allocate_child_impl(Args&&... args) { 160 F* obj = new F{std::forward<Args>(args)...}; 161 obj->m_type = task_type::allocated; 162 obj->reset_parent(this); 163 return obj; 164 } 165 166 base_task* reset_parent(base_task* ptr = nullptr) { 167 auto p = m_parent; 168 m_parent = ptr; 169 return p; 170 } 171 172 base_task* m_parent{nullptr}; 173 std::atomic<std::uint64_t> m_child_counter{0}; 174 }; 175 176 class root_task : public base_task { 177 public: 178 root_task(tbb::task_group& tg) : m_tg(tg), m_callback(m_tg.defer([] { /* Create empty callback to preserve reference for wait. */})) { 179 add_child_reference(); 180 m_type = base_task::task_type::allocated; 181 } 182 183 private: 184 base_task* execute() override { 185 m_tg.run(std::move(m_callback)); 186 return nullptr; 187 } 188 189 tbb::task_group& m_tg; 190 tbb::task_handle m_callback; 191 }; 192 193 template <typename F, typename... Args> 194 F create_root_task(tbb::task_group& tg, Args&&... args) { 195 F obj{std::forward<Args>(args)...}; 196 obj.m_type = base_task::task_type::stack_based; 197 obj.reset_parent(new root_task{tg}); 198 return obj; 199 } 200 201 template <typename F, typename... Args> 202 F* allocate_root_task(tbb::task_group& tg, Args&&... args) { 203 F* obj = new F{std::forward<Args>(args)...}; 204 obj->m_type = base_task::task_type::allocated; 205 obj->reset_parent(new root_task{tg}); 206 return obj; 207 } 208 209 template <typename F> 210 void run_task(F&& f) { 211 tg_pool[tbb::this_task_arena::current_thread_index()].run(std::forward<F>(f)); 212 } 213 214 template <typename F> 215 void run_task(F* f) { 216 tg_pool[tbb::this_task_arena::current_thread_index()].run(std::ref(*f)); 217 } 218 219 template <typename F> 220 void run_and_wait(tbb::task_group& tg, F* f) { 221 tg.run_and_wait(std::ref(*f)); 222 } 223 } // namespace task_emulation 224 225 #endif // __TBB_task_emulation_layer_H 226