// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#undef DLIB_RBf_NETWORK_ABSTRACT_
#ifdef DLIB_RBf_NETWORK_ABSTRACT_

#include "../algs.h"
#include "function_abstract.h"
#include "kernel_abstract.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename K 
        >
    class rbf_network_trainer 
    {
        /*!
            REQUIREMENTS ON K 
                is a kernel function object as defined in dlib/svm/kernel_abstract.h 
                (since this is supposed to be an RBF network it is probably reasonable
                to use some sort of radial basis kernel)

            INITIAL VALUE
                - get_num_centers() == 10 

            WHAT THIS OBJECT REPRESENTS
                This object implements a trainer for a radial basis function network.

                The implementation of this algorithm follows the normal RBF training 
                process.  For more details see the code or the Wikipedia article
                about RBF networks.  
        !*/

    public:
        // Types forwarded from the kernel so that generic code can query them
        // through the trainer.
        typedef K kernel_type;
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;

        // The type of object produced by train().
        typedef decision_function<kernel_type> trained_function_type;

        rbf_network_trainer (
        ); 
        /*!
            ensures
                - this object is properly initialized
        !*/

        void set_kernel (
            const kernel_type& k
        );
        /*!
            ensures
                - #get_kernel() == k 
        !*/

        const kernel_type& get_kernel (
        ) const;
        /*!
            ensures
                - returns a const reference to the kernel function in use by 
                  this object
        !*/

        void set_num_centers (
            const unsigned long num_centers
        );
        /*!
            ensures
                - #get_num_centers() == num_centers
        !*/

        const unsigned long get_num_centers (
        ) const;
        /*!
            ensures
                - returns the maximum number of centers (a.k.a. basis_vectors in the 
                  trained decision_function) you will get when you train this object on data.
        !*/

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const;
        /*!
            requires
                - x == a matrix or something convertible to a matrix via mat().
                  Also, x should contain sample_type objects.
                - y == a matrix or something convertible to a matrix via mat().
                  Also, y should contain scalar_type objects.
                - is_learning_problem(x,y) == true
            ensures
                - trains an RBF network given the training samples in x and 
                  labels in y and returns the resulting decision_function
            throws
                - std::bad_alloc
        !*/

        void swap (
            rbf_network_trainer& item
        );
        /*!
            ensures
                - swaps *this and item
        !*/

    }; 

// ----------------------------------------------------------------------------------------

    template <
        typename K
        >
    void swap (
        rbf_network_trainer<K>& a,
        rbf_network_trainer<K>& b
    )
    {
        a.swap(b);
    }
    /*!
        ensures
            - provides a global swap for rbf_network_trainer objects.  This is
              done by forwarding to the member function, i.e. it calls a.swap(b)
    !*/

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_RBf_NETWORK_ABSTRACT_