// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
#define DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_

#include "cross_validate_assignment_trainer_abstract.h"
#include <vector>
#include "../matrix.h"
#include "svm.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Scores an assignment function against known-correct assignments.
    //
    // Every sample in samples is passed to assigner, and each element of the
    // returned assignment vector is compared with the element at the same position
    // in the matching entry of labels.
    //
    // Parameters:
    //   assigner - callable object; assigner(samples[i]) must yield something that
    //              binds to a const std::vector<long>&.  Its forces_assignment()
    //              member is only consulted by the checks below.
    //   samples  - the inputs to evaluate.
    //   labels   - labels[i] holds the expected assignments for samples[i].
    //
    // Returns the fraction of predicted assignments that equal their label, or 1
    // if the assigner produced no assignments at all (e.g. samples is empty).
    //
    // When ENABLE_ASSERTS is defined, the inputs are checked with
    // is_forced_assignment_problem() or is_assignment_problem(), depending on
    // assigner.forces_assignment().
    template <
        typename assignment_function
        >
    double test_assignment_function (
        const assignment_function& assigner,
        const std::vector<typename assignment_function::sample_type>& samples,
        const std::vector<typename assignment_function::label_type>& labels
    )
    {
        // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
        if (assigner.forces_assignment())
        {
            DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
                "\t double test_assignment_function()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
                << "\n\t is_assignment_problem(samples,labels):        " << is_assignment_problem(samples,labels)
                << "\n\t is_learning_problem(samples,labels):          " << is_learning_problem(samples,labels)
                );
        }
        else
        {
            DLIB_ASSERT(is_assignment_problem(samples, labels),
                "\t double test_assignment_function()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
                << "\n\t is_learning_problem(samples,labels):   " << is_learning_problem(samples,labels)
                );
        }
#endif
        // Counters are kept as doubles so the final ratio needs no conversion.
        double num_correct = 0;
        double num_predictions = 0;

        for (unsigned long s = 0; s != samples.size(); ++s)
        {
            const std::vector<long>& predicted = assigner(samples[s]);
            const typename assignment_function::label_type& truth = labels[s];

            // Every predicted assignment is scored, right or wrong.
            num_predictions += predicted.size();

            for (unsigned long k = 0; k != predicted.size(); ++k)
            {
                if (predicted[k] == truth[k])
                    num_correct += 1;
            }
        }

        // Nothing was predicted, so nothing was predicted wrongly.
        if (num_predictions == 0)
            return 1;

        return num_correct/num_predictions;
    }

// ----------------------------------------------------------------------------------------

    // Runs a folds-fold cross validation of an assignment trainer and returns the
    // fraction of held-out assignments that were predicted correctly.
    //
    // Each of the folds iterations does the following:
    //   - The next samples.size()/folds samples (continuing from where the previous
    //     fold stopped, wrapping around the end of samples) become the test set.
    //   - The samples.size() - samples.size()/folds samples that follow the test
    //     set, again with wrap-around, become the training set.
    //   - trainer.train() is called on the training set and the function it returns
    //     is applied to every test sample.  Each element of its output is compared
    //     with the element at the same position in the true label.
    //
    // Note that when samples.size() is not a multiple of folds, the leftover
    // samples are used for training but never appear in any test set.
    //
    // Parameters:
    //   trainer - object providing train(samples, labels), the nested types
    //             sample_type, label_type and trained_function_type, and
    //             forces_assignment() (only consulted by the checks below).
    //   samples - the full data set.
    //   labels  - labels[i] holds the expected assignments for samples[i].
    //   folds   - number of train/test rounds to perform.
    //
    // Returns total correct assignments divided by total assignments scored over
    // all folds, or 1 if no assignment was scored.
    //
    // When ENABLE_ASSERTS is defined, this function checks that
    // 1 < folds <= samples.size() and that the inputs satisfy
    // is_forced_assignment_problem() or is_assignment_problem(), depending on
    // trainer.forces_assignment().
    template <
        typename trainer_type
        >
    double cross_validate_assignment_trainer (
        const trainer_type& trainer,
        const std::vector<typename trainer_type::sample_type>& samples,
        const std::vector<typename trainer_type::label_type>& labels,
        const long folds
    )
    {
        // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
        if (trainer.forces_assignment())
        {
            DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
                        1 < folds && folds <= static_cast<long>(samples.size()),
                "\t double cross_validate_assignment_trainer()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t samples.size(): " << samples.size() 
                << "\n\t folds:  " << folds 
                << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
                << "\n\t is_assignment_problem(samples,labels):        " << is_assignment_problem(samples,labels)
                << "\n\t is_learning_problem(samples,labels):          " << is_learning_problem(samples,labels)
                );
        }
        else
        {
            DLIB_ASSERT(is_assignment_problem(samples, labels) &&
                        1 < folds && folds <= static_cast<long>(samples.size()),
                "\t double cross_validate_assignment_trainer()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t samples.size(): " << samples.size() 
                << "\n\t folds:  " << folds 
                << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
                << "\n\t is_learning_problem(samples,labels):   " << is_learning_problem(samples,labels)
                );
        }
#endif

        typedef typename trainer_type::sample_type sample_type;
        typedef typename trainer_type::label_type label_type;

        // Sizes of the two partitions; they are the same for every fold.
        const long num_in_test  = samples.size()/folds;
        const long num_in_train = samples.size() - num_in_test;

        std::vector<sample_type> samples_test, samples_train;
        std::vector<label_type> labels_test, labels_train;

        // The partition sizes are known up front, so allocate the buffers once.
        // clear() keeps the capacity, so later folds reuse the same storage.
        samples_test.reserve(static_cast<typename std::vector<sample_type>::size_type>(num_in_test));
        labels_test.reserve(static_cast<typename std::vector<label_type>::size_type>(num_in_test));
        samples_train.reserve(static_cast<typename std::vector<sample_type>::size_type>(num_in_train));
        labels_train.reserve(static_cast<typename std::vector<label_type>::size_type>(num_in_train));

        // Index of the first sample of the next fold's test set.  It persists
        // across folds so that consecutive folds test different samples.
        long next_test_idx = 0;
        double total_right = 0;
        double total = 0;

        for (long fold = 0; fold < folds; ++fold)
        {
            samples_test.clear();
            labels_test.clear();
            samples_train.clear();
            labels_train.clear();

            // load up the test samples
            for (long cnt = 0; cnt < num_in_test; ++cnt)
            {
                samples_test.push_back(samples[next_test_idx]);
                labels_test.push_back(labels[next_test_idx]);
                next_test_idx = (next_test_idx + 1)%samples.size();
            }

            // load up the training samples, starting right after the test set and
            // wrapping around so that none of the test samples are included
            long next = next_test_idx;
            for (long cnt = 0; cnt < num_in_train; ++cnt)
            {
                samples_train.push_back(samples[next]);
                labels_train.push_back(labels[next]);
                next = (next + 1)%samples.size();
            }

            // Binding to a const reference keeps a by-value result of train() alive
            // until the end of this fold.
            const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train);

            // check how good df is on the test data
            for (unsigned long t = 0; t < samples_test.size(); ++t)
            {
                const std::vector<long>& out = df(samples_test[t]);
                for (unsigned long j = 0; j < out.size(); ++j)
                {
                    if (out[j] == labels_test[t][j])
                        ++total_right;

                    ++total;
                }
            }

        } // for (long fold = 0; fold < folds; ++fold)

        // If nothing was scored there were no mistakes, so report perfect accuracy.
        if (total != 0)
            return total_right/total;
        else
            return 1;

    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_