// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_THREAD_POOl_CPPh_
#define DLIB_THREAD_POOl_CPPh_ 

#include "thread_pool_extension.h"
#include <memory>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Constructs a pool with num_threads task slots and launches one worker
    // thread per slot, each running thread().
    //
    // If launching a worker fails (std::thread's constructor throws), the
    // workers that were already started are told to exit and are joined before
    // the exception is rethrown.  Without this, the still-joinable std::thread
    // objects would be destroyed during stack unwinding, which calls
    // std::terminate, and the running workers would keep using members of an
    // object that is being torn down.
    thread_pool_implementation::
    thread_pool_implementation (
        unsigned long num_threads
    ) : 
        task_done_signaler(m),
        task_ready_signaler(m),
        we_are_destructing(false)
    {
        tasks.resize(num_threads);
        threads.resize(num_threads);
        try
        {
            for (unsigned long i = 0; i < num_threads; ++i)
            {
                threads[i] = std::thread([this](){this->thread();});
            }
        }
        catch (...)
        {
            // Ask every worker that did start to leave its loop.  The flag is set
            // under the mutex so a worker about to wait cannot miss the wakeup.
            {
                auto_mutex M(m);
                we_are_destructing = true;
                task_ready_signaler.broadcast();
            }

            // Slots past the failure point hold default constructed (non-joinable)
            // threads, so only join the ones that are actually running.
            for (auto& t : threads)
            {
                if (t.joinable())
                    t.join();
            }

            throw;
        }
    }

// ----------------------------------------------------------------------------------------

    void thread_pool_implementation::
    shutdown_pool (
    )
    {
        {
            auto_mutex M(m);
            
            // first wait for all pending tasks to finish
            bool found_task = true;
            while (found_task)
            {
                found_task = false;
                for (unsigned long i = 0; i < tasks.size(); ++i)
                {
                    // If task bucket i has a task that is currently supposed to be processed
                    if (tasks[i].is_empty() == false)
                    {
                        found_task = true;
                        break;
                    }
                }

                if (found_task)
                    task_done_signaler.wait();
            }

            // now tell the threads to kill themselves
            we_are_destructing = true;
            task_ready_signaler.broadcast();
        }

        // wait for all threads to terminate
        for (auto& t : threads)
            t.join();
        threads.clear();

        // Throw any unhandled exceptions.  Since shutdown_pool() is only called in the
        // destructor this will kill the program.
        for (auto&& task : tasks)
            task.propagate_exception();
    }

// ----------------------------------------------------------------------------------------

    // Destroying the pool shuts it down via shutdown_pool().
    thread_pool_implementation::
    ~thread_pool_implementation()
    {
        this->shutdown_pool();
    }

// ----------------------------------------------------------------------------------------

    // Returns the number of task slots in the pool (the constructor creates one
    // slot per worker thread).  Read while holding the pool mutex.
    unsigned long thread_pool_implementation::
    num_threads_in_pool (
    ) const
    {
        auto_mutex lock(m);
        const unsigned long slot_count = tasks.size();
        return slot_count;
    }

// ----------------------------------------------------------------------------------------

    // Blocks until the slot associated with task_id no longer carries that id,
    // then rethrows any exception stored in any task slot.  Does nothing when the
    // pool has no task slots.
    void thread_pool_implementation::
    wait_for_task (
        uint64 task_id
    ) const
    {
        auto_mutex lock(m);

        // no slots means nothing can be pending
        if (tasks.size() == 0)
            return;

        const unsigned long slot = task_id_to_index(task_id);
        while (tasks[slot].task_id == task_id)
        {
            task_done_signaler.wait();
        }

        // surface exceptions generated by the tasks
        for (auto&& entry : tasks)
        {
            entry.propagate_exception();
        }
    }

// ----------------------------------------------------------------------------------------

    void thread_pool_implementation::
    wait_for_all_tasks (
    ) const
    {
        const thread_id_type thread_id = get_thread_id();

        auto_mutex M(m);
        bool found_task = true;
        while (found_task)
        {
            found_task = false;
            for (unsigned long i = 0; i < tasks.size(); ++i)
            {
                // If task bucket i has a task that is currently supposed to be processed
                // and it originated from the calling thread
                if (tasks[i].is_empty() == false && tasks[i].thread_id == thread_id)
                {
                    found_task = true;
                    break;
                }
            }

            if (found_task)
                task_done_signaler.wait();
        }

        // throw any exceptions generated by the tasks
        for (auto&& task : tasks)
            task.propagate_exception();
    }

// ----------------------------------------------------------------------------------------

    // Returns true if id is one of the recorded worker thread ids.  A pool with
    // no task slots treats every thread as a worker thread.
    bool thread_pool_implementation::
    is_worker_thread (
        const thread_id_type id
    ) const
    {
        const unsigned long worker_count = worker_thread_ids.size();
        unsigned long pos = 0;
        while (pos < worker_count)
        {
            if (worker_thread_ids[pos] == id)
                return true;
            ++pos;
        }

        // Not a registered worker: the answer is still "yes" when the pool has
        // no threads at all.
        return tasks.size() == 0;
    }

// ----------------------------------------------------------------------------------------

    void thread_pool_implementation::
    thread (
    )
    {
        {
            // save the id of this worker thread into worker_thread_ids
            auto_mutex M(m);
            thread_id_type id = get_thread_id();
            worker_thread_ids.push_back(id);
        }

        task_state_type task;
        while (we_are_destructing == false)
        {
            long idx = 0;

            // wait for a task to do 
            { auto_mutex M(m);
                while ( (idx = find_ready_task()) == -1 && we_are_destructing == false)
                    task_ready_signaler.wait();

                if (we_are_destructing)
                    break;

                tasks[idx].is_being_processed = true;
                task = tasks[idx];
            }

            std::exception_ptr eptr = nullptr;
            try
            {
                // now do the task
                if (task.bfp)
                    task.bfp();
                else if (task.mfp0)
                    task.mfp0();
                else if (task.mfp1)
                    task.mfp1(task.arg1);
                else if (task.mfp2)
                    task.mfp2(task.arg1, task.arg2);
            }
            catch(...)
            {
                eptr = std::current_exception();
            }

            // Now let others know that we finished the task.  We do this
            // by clearing out the state of this task
            { auto_mutex M(m);
                tasks[idx].is_being_processed = false;
                tasks[idx].task_id = 0;
                tasks[idx].bfp.clear();
                tasks[idx].mfp0.clear();
                tasks[idx].mfp1.clear();
                tasks[idx].mfp2.clear();
                tasks[idx].arg1 = 0;
                tasks[idx].arg2 = 0;
                tasks[idx].eptr = eptr;
                task_done_signaler.broadcast();
            }

        }
    }

// ----------------------------------------------------------------------------------------

    // Rethrows any exception stored in a task slot, then returns the index of
    // the first empty slot, or -1 if every slot is occupied.
    long thread_pool_implementation::
    find_empty_task_slot (
    ) const
    {
        for (auto&& slot : tasks)
        {
            slot.propagate_exception();
        }

        const unsigned long slot_count = tasks.size();
        for (unsigned long pos = 0; pos != slot_count; ++pos)
        {
            if (tasks[pos].is_empty())
                return static_cast<long>(pos);
        }

        return -1;
    }

// ----------------------------------------------------------------------------------------

    // Returns the index of the first task slot reporting is_ready(), or -1 if
    // there is none.
    long thread_pool_implementation::
    find_ready_task (
    ) const
    {
        const unsigned long slot_count = tasks.size();
        for (unsigned long pos = 0; pos != slot_count; ++pos)
        {
            if (tasks[pos].is_ready())
                return static_cast<long>(pos);
        }

        return -1;
    }

// ----------------------------------------------------------------------------------------

    // Produces a fresh id for slot idx.  The id is built as
    // next_task_id*tasks.size() + idx, so id % tasks.size() gives back idx (see
    // task_id_to_index()).  The slot's counter is advanced for the next call.
    uint64 thread_pool_implementation::
    make_next_task_id (
        long idx
    )
    {
        auto& slot = tasks[idx];
        const uint64 fresh_id = slot.next_task_id * tasks.size() + idx;
        slot.next_task_id += 1;
        return fresh_id;
    }

// ----------------------------------------------------------------------------------------

    // Maps a task id to a slot index by reducing it modulo the number of slots.
    // The division requires tasks.size() to be non-zero.
    unsigned long thread_pool_implementation::
    task_id_to_index (
        uint64 id
    ) const
    {
        const uint64 slot = id % tasks.size();
        return static_cast<unsigned long>(slot);
    }

// ----------------------------------------------------------------------------------------

    // Queues bfp for execution and returns the id assigned to the task.
    //
    // If no slot is free and the caller is itself a worker thread, bfp is run
    // immediately on the calling thread (with the pool mutex released) and 1 is
    // returned.  Otherwise the call blocks until a slot frees up, stores bfp in
    // it, swaps item into the slot's function_copy, and wakes one worker.
    uint64 thread_pool_implementation::
    add_task_internal (
        const bfp_type& bfp,
        std::shared_ptr<function_object_copy>& item
    )
    {
        auto_mutex lock(m);
        const thread_id_type caller_id = get_thread_id();

        // look for a slot that isn't doing anything
        long slot = find_empty_task_slot();
        if (slot == -1)
        {
            if (is_worker_thread(caller_id))
            {
                // Called from inside a worker thread while every slot is busy,
                // so just perform the task right here.
                lock.unlock();
                bfp();

                // 1 is non-zero and is never normally handed out as a task id,
                // so wait_for_task() will never block when given this id.
                return 1;
            }

            // otherwise block until some slot becomes free
            do
            {
                task_done_signaler.wait();
                slot = find_empty_task_slot();
            } while (slot == -1);
        }

        // fill in the slot and hand it to a worker
        tasks[slot].thread_id = caller_id;
        tasks[slot].task_id = make_next_task_id(slot);
        tasks[slot].bfp = bfp;
        tasks[slot].function_copy.swap(item);

        task_ready_signaler.signal();

        return tasks[slot].task_id;
    }

// ----------------------------------------------------------------------------------------

    bool thread_pool_implementation::
    is_task_thread (
    ) const
    {
        auto_mutex M(m);
        return is_worker_thread(get_thread_id());
    }

// ----------------------------------------------------------------------------------------

}


#endif // DLIB_THREAD_POOl_CPPh_