// Copyright (C) 2008 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_THREAD_POOl_CPPh_ #define DLIB_THREAD_POOl_CPPh_ #include "thread_pool_extension.h" #include <memory> namespace dlib { // ---------------------------------------------------------------------------------------- thread_pool_implementation:: thread_pool_implementation ( unsigned long num_threads ) : task_done_signaler(m), task_ready_signaler(m), we_are_destructing(false) { tasks.resize(num_threads); threads.resize(num_threads); for (unsigned long i = 0; i < num_threads; ++i) { threads[i] = std::thread([&](){this->thread();}); } } // ---------------------------------------------------------------------------------------- void thread_pool_implementation:: shutdown_pool ( ) { { auto_mutex M(m); // first wait for all pending tasks to finish bool found_task = true; while (found_task) { found_task = false; for (unsigned long i = 0; i < tasks.size(); ++i) { // If task bucket i has a task that is currently supposed to be processed if (tasks[i].is_empty() == false) { found_task = true; break; } } if (found_task) task_done_signaler.wait(); } // now tell the threads to kill themselves we_are_destructing = true; task_ready_signaler.broadcast(); } // wait for all threads to terminate for (auto& t : threads) t.join(); threads.clear(); // Throw any unhandled exceptions. Since shutdown_pool() is only called in the // destructor this will kill the program. for (auto&& task : tasks) task.propagate_exception(); } // ---------------------------------------------------------------------------------------- thread_pool_implementation:: ~thread_pool_implementation() { shutdown_pool(); } // ---------------------------------------------------------------------------------------- unsigned long thread_pool_implementation:: num_threads_in_pool ( ) const { auto_mutex M(m); return tasks.size(); } // ---------------------------------------------------------------------------------------- void thread_pool_implementation:: wait_for_task ( uint64 task_id ) const { auto_mutex M(m); if (tasks.size() != 0) { const unsigned long idx = task_id_to_index(task_id); while (tasks[idx].task_id == task_id) task_done_signaler.wait(); for (auto&& task : tasks) task.propagate_exception(); } } // ---------------------------------------------------------------------------------------- void thread_pool_implementation:: wait_for_all_tasks ( ) const { const thread_id_type thread_id = get_thread_id(); auto_mutex M(m); bool found_task = true; while (found_task) { found_task = false; for (unsigned long i = 0; i < tasks.size(); ++i) { // If task bucket i has a task that is currently supposed to be processed // and it originated from the calling thread if (tasks[i].is_empty() == false && tasks[i].thread_id == thread_id) { found_task = true; break; } } if (found_task) task_done_signaler.wait(); } // throw any exceptions generated by the tasks for (auto&& task : tasks) task.propagate_exception(); } // ---------------------------------------------------------------------------------------- bool thread_pool_implementation:: is_worker_thread ( const thread_id_type id ) const { for (unsigned long i = 0; i < worker_thread_ids.size(); ++i) { if (worker_thread_ids[i] == id) return true; } // if there aren't any threads in the pool then we consider all threads // to be worker threads if (tasks.size() == 0) return true; else return false; } // ---------------------------------------------------------------------------------------- void thread_pool_implementation:: thread ( ) { { // save the id of this worker thread into worker_thread_ids auto_mutex M(m); thread_id_type id = get_thread_id(); worker_thread_ids.push_back(id); } task_state_type task; while (we_are_destructing == false) { long idx = 0; // wait for a task to do { auto_mutex M(m); while ( (idx = find_ready_task()) == -1 && we_are_destructing == false) task_ready_signaler.wait(); if (we_are_destructing) break; tasks[idx].is_being_processed = true; task = tasks[idx]; } std::exception_ptr eptr = nullptr; try { // now do the task if (task.bfp) task.bfp(); else if (task.mfp0) task.mfp0(); else if (task.mfp1) task.mfp1(task.arg1); else if (task.mfp2) task.mfp2(task.arg1, task.arg2); } catch(...) { eptr = std::current_exception(); } // Now let others know that we finished the task. We do this // by clearing out the state of this task { auto_mutex M(m); tasks[idx].is_being_processed = false; tasks[idx].task_id = 0; tasks[idx].bfp.clear(); tasks[idx].mfp0.clear(); tasks[idx].mfp1.clear(); tasks[idx].mfp2.clear(); tasks[idx].arg1 = 0; tasks[idx].arg2 = 0; tasks[idx].eptr = eptr; task_done_signaler.broadcast(); } } } // ---------------------------------------------------------------------------------------- long thread_pool_implementation:: find_empty_task_slot ( ) const { for (auto&& task : tasks) task.propagate_exception(); for (unsigned long i = 0; i < tasks.size(); ++i) { if (tasks[i].is_empty()) return i; } return -1; } // ---------------------------------------------------------------------------------------- long thread_pool_implementation:: find_ready_task ( ) const { for (unsigned long i = 0; i < tasks.size(); ++i) { if (tasks[i].is_ready()) return i; } return -1; } // ---------------------------------------------------------------------------------------- uint64 thread_pool_implementation:: make_next_task_id ( long idx ) { uint64 id = tasks[idx].next_task_id * tasks.size() + idx; tasks[idx].next_task_id += 1; return id; } // ---------------------------------------------------------------------------------------- unsigned long thread_pool_implementation:: task_id_to_index ( uint64 id ) const { return static_cast<unsigned long>(id%tasks.size()); } // ---------------------------------------------------------------------------------------- uint64 thread_pool_implementation:: add_task_internal ( const bfp_type& bfp, std::shared_ptr<function_object_copy>& item ) { auto_mutex M(m); const thread_id_type my_thread_id = get_thread_id(); // find a thread that isn't doing anything long idx = find_empty_task_slot(); if (idx == -1 && is_worker_thread(my_thread_id)) { // this function is being called from within a worker thread and there // aren't any other worker threads free so just perform the task right // here M.unlock(); bfp(); // return a task id that is both non-zero and also one // that is never normally returned. This way calls // to wait_for_task() will never block given this id. return 1; } // wait until there is a thread that isn't doing anything while (idx == -1) { task_done_signaler.wait(); idx = find_empty_task_slot(); } tasks[idx].thread_id = my_thread_id; tasks[idx].task_id = make_next_task_id(idx); tasks[idx].bfp = bfp; tasks[idx].function_copy.swap(item); task_ready_signaler.signal(); return tasks[idx].task_id; } // ---------------------------------------------------------------------------------------- bool thread_pool_implementation:: is_task_thread ( ) const { auto_mutex M(m); return is_worker_thread(get_thread_id()); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_THREAD_POOl_CPPh_