// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
#define DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_

#include <vector>
#include <cmath>
#include "../matrix.h"
#include "../statistics.h"
#include "cross_validate_regression_trainer_abstract.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename reg_funct_type,
        typename sample_type,
        typename label_type
        >
    matrix<double,1,2>
    test_regression_function (
        const reg_funct_type& reg_funct,
        const std::vector<sample_type>& x_test,
        const std::vector<label_type>& y_test
    )
    {

        // make sure requires clause is not broken
        DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
                    "\tmatrix test_regression_function()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t is_learning_problem(x_test,y_test): " 
                    << is_learning_problem(x_test,y_test));

        // rs accumulates the squared prediction errors and rc tracks the
        // correlation between the predicted and true target values.
        running_stats<double> rs;
        running_scalar_covariance<double> rc;

        for (unsigned long i = 0; i < x_test.size(); ++i)
        {
            // compute the prediction error for this sample
            const double output = reg_funct(x_test[i]);
            const double temp = output - y_test[i];

            rs.add(temp*temp);
            rc.add(output, y_test[i]);
        }

        // Return the mean squared error and the squared correlation (R^2)
        // between the predicted and true target values.  The comma expression
        // below is dlib's matrix comma initialization syntax.
        matrix<double,1,2> result;
        result = rs.mean(), std::pow(rc.correlation(),2);
        return result;
    }
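
    /*
        A minimal usage sketch (not part of this header).  It assumes dlib's
        krr_trainer, radial_basis_kernel, and decision_function are available
        and that the train/test data are supplied by the caller:

            typedef matrix<double,1,1> sample_type;
            typedef radial_basis_kernel<sample_type> kernel_type;

            std::vector<sample_type> x_train, x_test;
            std::vector<double> y_train, y_test;
            // ... fill these with samples and their true target values ...

            krr_trainer<kernel_type> trainer;
            trainer.set_kernel(kernel_type(0.1));
            const decision_function<kernel_type> df = trainer.train(x_train, y_train);

            // res(0) is the mean squared error of df on x_test and res(1) is
            // the squared correlation (R^2) between df's outputs and y_test.
            const matrix<double,1,2> res = test_regression_function(df, x_test, y_test);
    */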

// ----------------------------------------------------------------------------------------

    template <
        typename trainer_type,
        typename sample_type,
        typename label_type 
        >
    matrix<double,1,2> 
    cross_validate_regression_trainer (
        const trainer_type& trainer,
        const std::vector<sample_type>& x,
        const std::vector<label_type>& y,
        const long folds
    )
    {

        // make sure requires clause is not broken
        DLIB_ASSERT(is_learning_problem(x,y) == true &&
                    1 < folds && folds <= static_cast<long>(x.size()),
            "\tmatrix cross_validate_regression_trainer()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t x.size(): " << x.size() 
            << "\n\t folds:  " << folds 
            << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
            );

        // figure out how many samples go into each test fold; the rest of the
        // data is used for training.
        const long num_in_test = x.size()/folds;
        const long num_in_train = x.size() - num_in_test;

        // rs accumulates the squared prediction errors and rc tracks the
        // correlation between the predicted and true target values, pooled
        // over all the folds.
        running_stats<double> rs;
        running_scalar_covariance<double> rc;

        std::vector<sample_type> x_test, x_train;
        std::vector<label_type> y_test, y_train;

        // index of the first sample of the next test fold
        long next_test_idx = 0;

        for (long i = 0; i < folds; ++i)
        {
            x_test.clear();
            y_test.clear();
            x_train.clear();
            y_train.clear();

            // load up the test samples
            for (long cnt = 0; cnt < num_in_test; ++cnt)
            {
                x_test.push_back(x[next_test_idx]);
                y_test.push_back(y[next_test_idx]);
                next_test_idx = (next_test_idx + 1)%x.size();
            }

            // load up the training samples, wrapping around the end of the dataset
            long next = next_test_idx;
            for (long cnt = 0; cnt < num_in_train; ++cnt)
            {
                x_train.push_back(x[next]);
                y_train.push_back(y[next]);
                next = (next + 1)%x.size();
            }
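
            // For example, if x.size() == 10 and folds == 3 then num_in_test == 3
            // and num_in_train == 7: fold 0 tests on samples 0-2 and trains on 3-9,
            // fold 1 tests on 3-5 and trains on 6-9 plus 0-2 (wrapping around), and
            // fold 2 tests on 6-8 and trains on 9 plus 0-5.  Note that sample 9
            // never lands in a test fold since 10 isn't evenly divisible by 3.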


            try
            {
                // do the training for this fold
                const typename trainer_type::trained_function_type& df = trainer.train(x_train,y_train);

                // measure the error on the held-out test fold
                for (unsigned long j = 0; j < x_test.size(); ++j)
                {
                    // compute the prediction error for this sample
                    const double output = df(x_test[j]);
                    const double temp = output - y_test[j];

                    rs.add(temp*temp);
                    rc.add(output, y_test[j]);
                }
            }
            catch (invalid_nu_error&)
            {
                // Just ignore folds where the trainer throws invalid_nu_error
                // (i.e. a nu based trainer was given a nu value that is
                // infeasible for this particular training split).
            }

        } // for (long i = 0; i < folds; ++i)

        // Return the mean squared error and the squared correlation (R^2)
        // between the predicted and true target values, computed over all the
        // test folds.
        matrix<double,1,2> result;
        result = rs.mean(), std::pow(rc.correlation(),2);
        return result;
    }
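
    /*
        A minimal usage sketch (not part of this header).  It assumes dlib's
        krr_trainer and radial_basis_kernel are available; the sinusoidal data
        is purely illustrative:

            typedef matrix<double,1,1> sample_type;
            typedef radial_basis_kernel<sample_type> kernel_type;

            std::vector<sample_type> x;
            std::vector<double> y;
            sample_type s;
            for (double v = 0; v < 10; v += 0.1)
            {
                s(0) = v;
                x.push_back(s);
                y.push_back(std::sin(v));
            }

            krr_trainer<kernel_type> trainer;
            trainer.set_kernel(kernel_type(0.1));

            // 10-fold cross-validation: res(0) is the mean squared error and
            // res(1) is the squared correlation (R^2) over the test folds.
            const matrix<double,1,2> res = cross_validate_regression_trainer(trainer, x, y, 10);
    */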

}

// ----------------------------------------------------------------------------------------

#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_