// Copyright (C) 2008 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.

// NOTE: this file is documentation only.  The guard macro is #undef'ed
// immediately before it is tested with #ifdef, so nothing between the #ifdef
// and the matching #endif is ever seen by the compiler.
#undef DLIB_RBf_NETWORK_ABSTRACT_
#ifdef DLIB_RBf_NETWORK_ABSTRACT_

#include "../algs.h"
#include "function_abstract.h"
#include "kernel_abstract.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename K
        >
    class rbf_network_trainer
    {
        /*!
            REQUIREMENTS ON K
                is a kernel function object as defined in dlib/svm/kernel_abstract.h
                (since this is supposed to be an RBF network it is probably reasonable
                to use some sort of radial basis kernel)

            INITIAL VALUE
                - get_num_centers() == 10

            WHAT THIS OBJECT REPRESENTS
                This object implements a trainer for a radial basis function network.

                The implementation of this algorithm follows the normal RBF training
                process.  For more details see the code or the Wikipedia article
                about RBF networks.
        !*/

    public:

        // Convenience aliases: the scalar, sample and memory manager types are
        // all taken from the kernel type K.
        typedef K kernel_type;
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;

        // The type of object returned by train().
        typedef decision_function<kernel_type> trained_function_type;

        rbf_network_trainer (
        );
        /*!
            ensures
                - this object is properly initialized
        !*/

        void set_kernel (
            const kernel_type& k
        );
        /*!
            ensures
                - #get_kernel() == k
        !*/

        const kernel_type& get_kernel (
        ) const;
        /*!
            ensures
                - returns a const reference to the kernel function in use by this
                  object
        !*/

        void set_num_centers (
            const unsigned long num_centers
        );
        /*!
            ensures
                - #get_num_centers() == num_centers
        !*/

        const unsigned long get_num_centers (
        ) const;
        /*!
            ensures
                - returns the maximum number of centers (a.k.a. basis_vectors in the
                  trained decision_function) you will get when you train this object
                  on data.
        !*/

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const;
        /*!
            requires
                - x == a matrix or something convertible to a matrix via mat().
                  Also, x should contain sample_type objects.
                - y == a matrix or something convertible to a matrix via mat().
                  Also, y should contain scalar_type objects.
                - is_learning_problem(x,y) == true
            ensures
                - trains an RBF network given the training samples in x and
                  labels in y and returns the resulting decision_function
            throws
                - std::bad_alloc
        !*/

        void swap (
            rbf_network_trainer& item
        );
        /*!
            ensures
                - swaps *this and item
        !*/

    };

// ----------------------------------------------------------------------------------------

    template <typename K>
    void swap (
        rbf_network_trainer<K>& a,
        rbf_network_trainer<K>& b
    ) { a.swap(b); }
    /*!
        provides a global swap
    !*/

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_RBf_NETWORK_ABSTRACT_