// Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_SVm_SPARSE_KERNEL #define DLIB_SVm_SPARSE_KERNEL #include "sparse_kernel_abstract.h" #include <cmath> #include <limits> #include "../algs.h" #include "../serialize.h" #include "sparse_vector.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename T > struct sparse_radial_basis_kernel { typedef typename T::value_type::second_type scalar_type; typedef T sample_type; typedef default_memory_manager mem_manager_type; sparse_radial_basis_kernel(const scalar_type g) : gamma(g) {} sparse_radial_basis_kernel() : gamma(0.1) {} sparse_radial_basis_kernel( const sparse_radial_basis_kernel& k ) : gamma(k.gamma) {} const scalar_type gamma; scalar_type operator() ( const sample_type& a, const sample_type& b ) const { const scalar_type d = distance_squared(a,b); return std::exp(-gamma*d); } sparse_radial_basis_kernel& operator= ( const sparse_radial_basis_kernel& k ) { const_cast<scalar_type&>(gamma) = k.gamma; return *this; } bool operator== ( const sparse_radial_basis_kernel& k ) const { return gamma == k.gamma; } }; template < typename T > void serialize ( const sparse_radial_basis_kernel<T>& item, std::ostream& out ) { try { serialize(item.gamma, out); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type sparse_radial_basis_kernel"); } } template < typename T > void deserialize ( sparse_radial_basis_kernel<T>& item, std::istream& in ) { typedef typename T::value_type::second_type scalar_type; try { deserialize(const_cast<scalar_type&>(item.gamma), in); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while deserializing object of type sparse_radial_basis_kernel"); } } // ---------------------------------------------------------------------------------------- template < typename T > struct sparse_polynomial_kernel { typedef typename T::value_type::second_type scalar_type; typedef T sample_type; typedef default_memory_manager mem_manager_type; sparse_polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) : gamma(g), coef(c), degree(d) {} sparse_polynomial_kernel() : gamma(1), coef(0), degree(1) {} sparse_polynomial_kernel( const sparse_polynomial_kernel& k ) : gamma(k.gamma), coef(k.coef), degree(k.degree) {} typedef T type; const scalar_type gamma; const scalar_type coef; const scalar_type degree; scalar_type operator() ( const sample_type& a, const sample_type& b ) const { return std::pow(gamma*(dot(a,b)) + coef, degree); } sparse_polynomial_kernel& operator= ( const sparse_polynomial_kernel& k ) { const_cast<scalar_type&>(gamma) = k.gamma; const_cast<scalar_type&>(coef) = k.coef; const_cast<scalar_type&>(degree) = k.degree; return *this; } bool operator== ( const sparse_polynomial_kernel& k ) const { return (gamma == k.gamma) && (coef == k.coef) && (degree == k.degree); } }; template < typename T > void serialize ( const sparse_polynomial_kernel<T>& item, std::ostream& out ) { try { serialize(item.gamma, out); serialize(item.coef, out); serialize(item.degree, out); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type sparse_polynomial_kernel"); } } template < typename T > void deserialize ( sparse_polynomial_kernel<T>& item, std::istream& in ) { typedef typename T::value_type::second_type scalar_type; try { deserialize(const_cast<scalar_type&>(item.gamma), in); deserialize(const_cast<scalar_type&>(item.coef), in); deserialize(const_cast<scalar_type&>(item.degree), in); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while deserializing object of type sparse_polynomial_kernel"); } } // ---------------------------------------------------------------------------------------- template < typename T > struct sparse_sigmoid_kernel { typedef typename T::value_type::second_type scalar_type; typedef T sample_type; typedef default_memory_manager mem_manager_type; sparse_sigmoid_kernel(const scalar_type g, const scalar_type c) : gamma(g), coef(c) {} sparse_sigmoid_kernel() : gamma(0.1), coef(-1.0) {} sparse_sigmoid_kernel( const sparse_sigmoid_kernel& k ) : gamma(k.gamma), coef(k.coef) {} typedef T type; const scalar_type gamma; const scalar_type coef; scalar_type operator() ( const sample_type& a, const sample_type& b ) const { return std::tanh(gamma*(dot(a,b)) + coef); } sparse_sigmoid_kernel& operator= ( const sparse_sigmoid_kernel& k ) { const_cast<scalar_type&>(gamma) = k.gamma; const_cast<scalar_type&>(coef) = k.coef; return *this; } bool operator== ( const sparse_sigmoid_kernel& k ) const { return (gamma == k.gamma) && (coef == k.coef); } }; template < typename T > void serialize ( const sparse_sigmoid_kernel<T>& item, std::ostream& out ) { try { serialize(item.gamma, out); serialize(item.coef, out); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type sparse_sigmoid_kernel"); } } template < typename T > void deserialize ( sparse_sigmoid_kernel<T>& item, std::istream& in ) { typedef typename T::value_type::second_type scalar_type; try { deserialize(const_cast<scalar_type&>(item.gamma), in); deserialize(const_cast<scalar_type&>(item.coef), in); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while deserializing object of type sparse_sigmoid_kernel"); } } // ---------------------------------------------------------------------------------------- template <typename T> struct sparse_linear_kernel { typedef typename T::value_type::second_type scalar_type; typedef T sample_type; typedef default_memory_manager mem_manager_type; scalar_type operator() ( const sample_type& a, const sample_type& b ) const { return dot(a,b); } bool operator== ( const sparse_linear_kernel& ) const { return true; } }; template < typename T > void serialize ( const sparse_linear_kernel<T>& , std::ostream& ){} template < typename T > void deserialize ( sparse_linear_kernel<T>& , std::istream& ){} // ---------------------------------------------------------------------------------------- template <typename T> struct sparse_histogram_intersection_kernel { typedef typename T::value_type::second_type scalar_type; typedef T sample_type; typedef default_memory_manager mem_manager_type; scalar_type operator() ( const sample_type& a, const sample_type& b ) const { typename sample_type::const_iterator ai = a.begin(); typename sample_type::const_iterator bi = b.begin(); scalar_type sum = 0; while (ai != a.end() && bi != b.end()) { if (ai->first == bi->first) { sum += std::min(ai->second , bi->second); ++ai; ++bi; } else if (ai->first < bi->first) { ++ai; } else { ++bi; } } return sum; } bool operator== ( const sparse_histogram_intersection_kernel& ) const { return true; } }; template < typename T > void serialize ( const sparse_histogram_intersection_kernel<T>& , std::ostream& ){} template < typename T > void deserialize ( sparse_histogram_intersection_kernel<T>& , std::istream& ){} // ---------------------------------------------------------------------------------------- } #endif // DLIB_SVm_SPARSE_KERNEL