// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SVM_C_EKm_TRAINER_Hh_
#define DLIB_SVM_C_EKm_TRAINER_Hh_

#include "../algs.h"
#include "function.h"
#include "kernel.h"
#include "empirical_kernel_map.h"
#include "svm_c_linear_trainer.h"
#include "svm_c_ekm_trainer_abstract.h"
#include "../statistics.h"
#include "../rand.h"
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

namespace dlib
{
    template <
        typename K 
        >
    class svm_c_ekm_trainer
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                A trainer for kernelized C-SVMs.  Each training sample is projected
                into the span of a finite basis with an empirical_kernel_map (EKM)
                and a linear SVM (svm_c_linear_trainer) is trained on the projected
                vectors.  The learned linear weight vector is then converted back
                into a decision_function<kernel_type>.

                The basis is either the one supplied with set_basis() or, when no
                basis is loaded, one grown automatically from the training samples
                by repeatedly adding margin violators (see do_train_auto_basis()).

            THREAD SAFETY
                train() is const but it modifies the mutable ekm and ekm_stale
                members, so a single instance must not be used to train from
                several threads at the same time.
        !*/

    public:
        typedef K kernel_type;
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;
        typedef decision_function<kernel_type> trained_function_type;

        svm_c_ekm_trainer (
        )
        /*!
            ensures
                - verbose output is off and no basis is loaded
                - initial basis size is 10, basis size increment is 50 and the
                  maximum basis size is 300
                - C, epsilon and the iteration limit are those of a default
                  constructed svm_c_linear_trainer
        !*/
        {
            verbose = false;
            ekm_stale = true;

            initial_basis_size = 10;
            basis_size_increment = 50;
            max_basis_size = 300;
        }

        explicit svm_c_ekm_trainer (
            const scalar_type& C 
        )
        /*!
            requires
                - C > 0
            ensures
                - same as the default constructor except both class C values are C
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C > 0,
                "\t svm_c_ekm_trainer::svm_c_ekm_trainer()"
                << "\n\t C must be greater than 0"
                << "\n\t C:    " << C 
                << "\n\t this: " << this
                );


            ocas.set_c(C);
            verbose = false;
            ekm_stale = true;

            initial_basis_size = 10;
            basis_size_increment = 50;
            max_basis_size = 300;
        }

        void set_epsilon (
            scalar_type eps
        )
        /*!
            requires
                - eps > 0
            ensures
                - sets the stopping tolerance used by the final linear SVM training
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps > 0,
                "\t void svm_c_ekm_trainer::set_epsilon()"
                << "\n\t eps must be greater than 0"
                << "\n\t eps: " << eps 
                << "\n\t this: " << this
                );

            ocas.set_epsilon(eps);
        }

        const scalar_type get_epsilon (
        ) const
        {
            return ocas.get_epsilon();
        }

        void set_max_iterations (
            unsigned long max_iter
        )
        {
            ocas.set_max_iterations(max_iter);
        }

        // const so it can be queried through a const reference, like every other
        // getter of this class.
        unsigned long get_max_iterations (
        ) const
        {
            return ocas.get_max_iterations();
        }

        void be_verbose (
        ) 
        /*!
            ensures
                - this object prints its own progress messages while the internal
                  linear solver stays quiet
        !*/
        { 
            verbose = true;
            ocas.be_quiet(); 
        }

        void be_very_verbose (
        )
        /*!
            ensures
                - this object and the internal linear solver both print progress
        !*/
        {
            verbose = true;
            ocas.be_verbose(); 
        }

        void be_quiet (
        )
        { 
            verbose = false;
            ocas.be_quiet(); 
        }

        void set_oca (
            const oca& item
        )
        {
            ocas.set_oca(item);
        }

        const oca get_oca (
        ) const
        {
            return ocas.get_oca();
        }

        const kernel_type get_kernel (
        ) const
        {
            return kern;
        }

        void set_kernel (
            const kernel_type& k
        )
        /*!
            ensures
                - #get_kernel() == k
                - a user supplied basis will be reloaded into the EKM on next train()
        !*/
        {
            kern = k;
            ekm_stale = true;
        }

        template <typename T>
        void set_basis (
            const T& basis_samples
        )
        /*!
            requires
                - basis_samples is a non-empty vector of samples
            ensures
                - #basis_loaded() == true
                - subsequent calls to train() use exactly these basis samples
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)),
                "\tvoid svm_c_ekm_trainer::set_basis(basis_samples)"
                << "\n\t You have to give a non-empty set of basis_samples and it must be a vector"
                << "\n\t basis_samples.size():                       " << basis_samples.size() 
                << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples)) 
                << "\n\t this: " << this
                );

            basis = mat(basis_samples);
            ekm_stale = true;
        }

        bool basis_loaded(
        ) const
        {
            return (basis.size() != 0);
        }

        void clear_basis (
        )
        /*!
            ensures
                - #basis_loaded() == false, so train() picks a basis automatically
        !*/
        {
            basis.set_size(0);
            ekm.clear();
            ekm_stale = true;
        }

        unsigned long get_max_basis_size (
        ) const
        {
            return max_basis_size;
        }

        void set_max_basis_size (
            unsigned long max_basis_size_
        )
        /*!
            requires
                - max_basis_size_ > 0
            ensures
                - #get_max_basis_size() == max_basis_size_
                - #get_initial_basis_size() == min(get_initial_basis_size(), max_basis_size_)
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(max_basis_size_ > 0,
                "\t void svm_c_ekm_trainer::set_max_basis_size()"
                << "\n\t max_basis_size_ must be greater than 0"
                << "\n\t max_basis_size_: " << max_basis_size_ 
                << "\n\t this:            " << this
                );

            max_basis_size = max_basis_size_;
            if (initial_basis_size > max_basis_size)
                initial_basis_size = max_basis_size;
        }

        unsigned long get_initial_basis_size (
        ) const
        {
            return initial_basis_size;
        }

        void set_initial_basis_size (
            unsigned long initial_basis_size_
        )
        /*!
            requires
                - initial_basis_size_ > 0
            ensures
                - #get_initial_basis_size() == initial_basis_size_
                - #get_max_basis_size() == max(get_max_basis_size(), initial_basis_size_)
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(initial_basis_size_ > 0,
                "\t void svm_c_ekm_trainer::set_initial_basis_size()"
                << "\n\t initial_basis_size_ must be greater than 0"
                << "\n\t initial_basis_size_: " << initial_basis_size_ 
                << "\n\t this:                " << this
                );

            initial_basis_size = initial_basis_size_;

            if (initial_basis_size > max_basis_size)
                max_basis_size = initial_basis_size;
        }

        unsigned long get_basis_size_increment (
        ) const
        {
            return basis_size_increment;
        }

        void set_basis_size_increment (
            unsigned long basis_size_increment_
        )
        /*!
            requires
                - basis_size_increment_ > 0
            ensures
                - at most basis_size_increment_ vectors are added to an automatic
                  basis per iteration of the basis search
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(basis_size_increment_ > 0,
                "\t void svm_c_ekm_trainer::set_basis_size_increment()"
                << "\n\t basis_size_increment_ must be greater than 0"
                << "\n\t basis_size_increment_: " << basis_size_increment_ 
                << "\n\t this:                  " << this
                );

            basis_size_increment = basis_size_increment_;
        }

        void set_c (
            scalar_type C 
        )
        /*!
            requires
                - C > 0
            ensures
                - sets the C value of both classes
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C > 0,
                "\t void svm_c_ekm_trainer::set_c()"
                << "\n\t C must be greater than 0"
                << "\n\t C:    " << C 
                << "\n\t this: " << this
                );

            ocas.set_c(C);
        }

        const scalar_type get_c_class1 (
        ) const
        {
            return ocas.get_c_class1();
        }

        const scalar_type get_c_class2 (
        ) const
        {
            return ocas.get_c_class2();
        }

        void set_c_class1 (
            scalar_type C
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C > 0,
                "\t void svm_c_ekm_trainer::set_c_class1()"
                << "\n\t C must be greater than 0"
                << "\n\t C:    " << C 
                << "\n\t this: " << this
                );

            ocas.set_c_class1(C);
        }

        void set_c_class2 (
            scalar_type C
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C > 0,
                "\t void svm_c_ekm_trainer::set_c_class2()"
                << "\n\t C must be greater than 0"
                << "\n\t C:    " << C 
                << "\n\t this: " << this
                );

            ocas.set_c_class2(C);
        }

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const
        /*!
            requires
                - is_binary_classification_problem(x,y) == true
            ensures
                - trains with the user basis if basis_loaded(), otherwise with an
                  automatically chosen basis, and returns the decision function
        !*/
        {
            scalar_type obj;
            if (basis_loaded())
                return do_train_user_basis(mat(x),mat(y),obj);
            else
                return do_train_auto_basis(mat(x),mat(y),obj);
        }

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y,
            scalar_type& svm_objective
        ) const
        /*!
            requires
                - is_binary_classification_problem(x,y) == true
            ensures
                - same as train(x,y) and additionally #svm_objective is the
                  objective value reported by the final linear SVM training
        !*/
        {
            if (basis_loaded())
                return do_train_user_basis(mat(x),mat(y),svm_objective);
            else
                return do_train_auto_basis(mat(x),mat(y),svm_objective);
        }


    private:

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> do_train_user_basis (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y,
            scalar_type& svm_objective
        ) const
        /*!
            requires
                - basis_loaded() == true
            ensures
                - trains an SVM with the user supplied basis
                - reloads the EKM only when kern or basis changed since last load
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
                "\t decision_function svm_c_ekm_trainer::train(x,y)"
                << "\n\t invalid inputs were given to this function"
                << "\n\t x.nr(): " << x.nr() 
                << "\n\t y.nr(): " << y.nr() 
                << "\n\t x.nc(): " << x.nc() 
                << "\n\t y.nc(): " << y.nc() 
                << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
                );

            if (ekm_stale)
            {
                ekm.load(kern, basis);
                ekm_stale = false;
            }

            // project all the samples with the ekm, recording the projection error
            // only when it will actually be reported
            running_stats<scalar_type> rs;
            std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples;
            proj_samples.reserve(x.size());
            for (long i = 0; i < x.size(); ++i)
            {
                if (verbose)
                {
                    scalar_type err;
                    proj_samples.push_back(ekm.project(x(i), err));
                    rs.add(err);
                }
                else
                {
                    proj_samples.push_back(ekm.project(x(i)));
                }
            }

            if (verbose)
            {
                std::cout << "\nMean EKM projection error:                  " << rs.mean() << std::endl;
                std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
            }
            
            // now do the training
            decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
            df = ocas.train(proj_samples, y, svm_objective);

            if (verbose)
            {
                std::cout << "Final svm objective: " << svm_objective << std::endl;
            }

            // map the linear weight vector back into a kernel expansion
            decision_function<kernel_type> final_df;
            final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
            final_df.b = df.b;
            return final_df;
        }

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> do_train_auto_basis (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y,
            scalar_type& svm_objective
        ) const
        /*!
            requires
                - basis_loaded() == false
            ensures
                - trains an SVM using a basis chosen automatically from x:
                    - starts from up to initial_basis_size randomly drawn,
                      linearly independent samples
                    - repeatedly trains a linear SVM on the EKM projections and
                      adds up to basis_size_increment margin violators to the
                      basis, stopping when the objective stops decreasing, when no
                      vector could be added, or once max_basis_size is reached
                    - reprojects all samples and retrains with get_epsilon()
                - #svm_objective == the objective of that final training
                - the internal ekm is cleared before returning
        !*/
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
                "\t decision_function svm_c_ekm_trainer::train(x,y)"
                << "\n\t invalid inputs were given to this function"
                << "\n\t x.nr(): " << x.nr() 
                << "\n\t y.nr(): " << y.nr() 
                << "\n\t x.nc(): " << x.nc() 
                << "\n\t y.nc(): " << y.nc() 
                << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
                );

            // This function overwrites ekm with an automatically chosen basis, so the
            // ekm no longer matches any user basis.  Record that explicitly so a later
            // do_train_user_basis() always reloads it.
            ekm_stale = true;

            std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size());
            decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;

            // we will use a linearly_independent_subset_finder to store our basis set. 
            linearly_independent_subset_finder<kernel_type> lisf(get_kernel(), max_basis_size);

            dlib::rand rnd;

            // First pick the initial basis set randomly.  The draw count is capped at
            // 10*initial_basis_size so data with few linearly independent samples
            // cannot keep this loop going indefinitely.
            for (unsigned long i = 0; i < 10*initial_basis_size && lisf.size() < initial_basis_size; ++i)
            {
                lisf.add(x(rnd.get_random_32bit_number()%x.size()));
            }

            ekm.load(lisf);

            // project all samples into the span of the current basis 
            for (long i = 0; i < x.size(); ++i)
            {
                proj_samples[i] = ekm.project(x(i));
            }


            svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > trainer(ocas);

            const scalar_type min_epsilon = trainer.get_epsilon();
            // While we are determining what the basis set will be we use a loose
            // stopping condition and tighten it back up before producing the final
            // decision_function.  Never make it tighter than what the user asked for.
            const scalar_type coarse_epsilon = 0.2;
            trainer.set_epsilon(std::max(coarse_epsilon, min_epsilon));

            scalar_type prev_svm_objective = std::numeric_limits<scalar_type>::max();

            empirical_kernel_map<kernel_type> prev_ekm;

            // This loop is where we try to generate a basis for SVM training.  We do
            // this by repeatedly training the SVM and adding a few points which violate
            // the margin to the basis in each iteration.
            while (true)
            {
                // if the basis is already as big as it's going to get then just do the
                // most accurate training right now.  
                if (lisf.size() == max_basis_size)
                    trainer.set_epsilon(min_epsilon);

                while (true)
                {
                    // now do the training.  
                    df = trainer.train(proj_samples, y, svm_objective);

                    if (svm_objective < prev_svm_objective)
                        break;

                    // If the training didn't reduce the objective below last time then
                    // try lowering the epsilon and doing it again.
                    if (trainer.get_epsilon() > min_epsilon)
                    {
                        // Halve epsilon without going below min_epsilon.  The halved
                        // value is stored as scalar_type so std::max gets two arguments
                        // of the same type (scalar_type*0.5 is double when scalar_type
                        // is float, which std::max cannot deduce).
                        const scalar_type halved_epsilon = trainer.get_epsilon()/2;
                        trainer.set_epsilon(std::max(halved_epsilon, min_epsilon));
                        if (verbose)
                            std::cout << " *** Reducing epsilon to " << trainer.get_epsilon() << std::endl;
                    }
                    else
                        break;
                }

                if (verbose)
                {
                    std::cout << "svm objective: " << svm_objective << std::endl;
                    std::cout << "basis size: " << lisf.size() << std::endl;
                }

                // if we failed to make progress on this iteration then we are done
                if (svm_objective >= prev_svm_objective)
                    break;

                prev_svm_objective = svm_objective;

                // now add more elements to the basis
                unsigned long count = 0;
                for (unsigned long j = 0; 
                     (j < 100*basis_size_increment) && (count < basis_size_increment) && (lisf.size() < max_basis_size); 
                     ++j)
                {
                    // pick a random sample
                    const unsigned long idx = rnd.get_random_32bit_number()%x.size();
                    // If it is a margin violator then it is useful to add it into the basis set.
                    if (df(proj_samples[idx])*y(idx) < 1)
                    {
                        // Add the sample into the basis set if it is linearly independent
                        // of all the vectors already in the basis set.  
                        if (lisf.add(x(idx)))
                        {
                            ++count;
                        }
                    }
                }
                // if we couldn't add any more basis vectors then stop
                if (count == 0)
                {
                    if (verbose)
                        std::cout << "Stopping, couldn't add more basis vectors." << std::endl;
                    break;
                }


                // Project all the samples into the span of our newly enlarged basis.  We
                // do this using the special transformation in the EKM that lets us
                // project from a smaller basis set to a larger one without needing to
                // reevaluate kernel functions we have already computed.
                ekm.swap(prev_ekm);
                ekm.load(lisf);
                projection_function<kernel_type> proj_part;
                // Use the EKM's own scalar type and memory manager so this matrix matches
                // get_transformation_to() and can be multiplied with proj_samples[i]
                // (a matrix<double> only works for double kernels with the default
                // memory manager).
                matrix<scalar_type,0,0,mem_manager_type> prev_to_new;
                prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part);

                
                matrix<scalar_type,0,1, mem_manager_type> temp;
                for (long i = 0; i < x.size(); ++i)
                {
                    // assign to temporary to avoid memory allocation that would result if
                    // we assigned this expression straight into proj_samples[i]
                    temp = prev_to_new*proj_samples[i] + proj_part(x(i));
                    proj_samples[i] = temp;
                }
            }
            
            // Reproject all the data samples using the final basis.  We could just use
            // what we already have but the recursive update done above to compute the
            // proj_samples might have accumulated a little numerical error.  So let's
            // just be safe.
            running_stats<scalar_type> rs, rs_margin;
            for (long i = 0; i < x.size(); ++i)
            {
                if (verbose)
                {
                    scalar_type err;
                    proj_samples[i] = ekm.project(x(i),err);
                    rs.add(err);
                    // if this point is within the margin 
                    if (df(proj_samples[i])*y(i) < 1)
                        rs_margin.add(err);
                }
                else
                {
                    proj_samples[i] = ekm.project(x(i));
                }
            }

            // do the final training
            trainer.set_epsilon(min_epsilon);
            df = trainer.train(proj_samples, y, svm_objective);


            if (verbose)
            {
                std::cout << "\nMean EKM projection error:                  " << rs.mean() << std::endl;
                std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
                std::cout << "Mean EKM projection error for margin violators:                  " << rs_margin.mean() << std::endl;
                std::cout << "Standard deviation of EKM projection error for margin violators: " << ((rs_margin.current_n()>1)?rs_margin.stddev():0) << std::endl;

                std::cout << "Final svm objective: " << svm_objective << std::endl;
            }


            // map the linear weight vector back into a kernel expansion
            decision_function<kernel_type> final_df;
            final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
            final_df.b = df.b;

            // we don't need the ekm anymore so clear it out
            ekm.clear();

            return final_df;
        }




        /*!
            CONVENTION
                - if (ekm_stale) then
                    - kern or basis have changed since the last time
                      they were loaded into the ekm, or the ekm currently holds
                      something other than the user supplied basis
        !*/

        svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > ocas;
        bool verbose;

        kernel_type kern;
        unsigned long max_basis_size;
        unsigned long basis_size_increment;
        unsigned long initial_basis_size;


        matrix<sample_type,0,1,mem_manager_type> basis;
        mutable empirical_kernel_map<kernel_type> ekm;
        mutable bool ekm_stale; 

    }; 

}

#endif // DLIB_SVM_C_EKm_TRAINER_Hh_