// Copyright (C) 2009  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SVm_SPARSE_KERNEL
#define DLIB_SVm_SPARSE_KERNEL

#include "sparse_kernel_abstract.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include "../algs.h"
#include "../serialize.h"
#include "sparse_vector.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename T
        >
    struct sparse_radial_basis_kernel
    {
        /*!
            Gaussian (radial basis function) kernel for sparse vectors:
                k(a,b) == std::exp(-gamma * distance_squared(a,b))
            A default constructed kernel uses gamma == 0.1.
        !*/

        typedef typename T::value_type::second_type scalar_type;
        typedef T sample_type;
        typedef default_memory_manager mem_manager_type;

        // Kernel with an explicit gamma.
        sparse_radial_basis_kernel(const scalar_type g) : gamma(g) {}

        // Kernel with gamma == 0.1.
        sparse_radial_basis_kernel() : gamma(0.1) {}

        // Copies gamma from k.
        sparse_radial_basis_kernel(const sparse_radial_basis_kernel& k) : gamma(k.gamma) {}

        // Kernel width parameter.  It is declared const, so operator= below and the
        // deserialize() overload for this type write to it through const_cast.
        // NOTE(review): modifying an object declared const is formally undefined
        // behavior in C++; kept as-is so the public member's type does not change.
        const scalar_type gamma;

        scalar_type operator() (const sample_type& a, const sample_type& b) const
        {
            // The squared distance is converted to scalar_type before scaling.
            const scalar_type dist_sq = distance_squared(a,b);
            return std::exp(-gamma*dist_sq);
        }

        sparse_radial_basis_kernel& operator= (const sparse_radial_basis_kernel& k)
        {
            scalar_type& writable_gamma = const_cast<scalar_type&>(gamma);
            writable_gamma = k.gamma;
            return *this;
        }

        // Two kernels are equal when their gamma values compare equal.
        bool operator== (const sparse_radial_basis_kernel& k) const
        {
            const bool same_gamma = (gamma == k.gamma);
            return same_gamma;
        }
    };

    template <
        typename T
        >
    void serialize (
        const sparse_radial_basis_kernel<T>& item,
        std::ostream& out
    )
    {
        try
        {
            serialize(item.gamma, out);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n   while serializing object of type sparse_radial_basis_kernel"); 
        }
    }

    template <
        typename T
        >
    void deserialize (
        sparse_radial_basis_kernel<T>& item,
        std::istream& in 
    )
    {
        // Reads gamma from in directly into item.  gamma is declared const in the
        // kernel, hence the const_cast; failures are rethrown with added context.
        typedef typename sparse_radial_basis_kernel<T>::scalar_type scalar_type;
        try
        {
            scalar_type& target = const_cast<scalar_type&>(item.gamma);
            deserialize(target, in);
        }
        catch (const serialization_error& err)
        {
            throw serialization_error(err.info + "\n   while deserializing object of type sparse_radial_basis_kernel"); 
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename T
        >
    struct sparse_polynomial_kernel
    {
        /*!
            Polynomial kernel for sparse vectors:
                k(a,b) == std::pow(gamma*dot(a,b) + coef, degree)
            A default constructed kernel has gamma == 1, coef == 0 and degree == 1.
        !*/

        typedef typename T::value_type::second_type scalar_type;
        typedef T sample_type;
        typedef default_memory_manager mem_manager_type;

        // Kernel with explicit gamma, coef and degree.
        sparse_polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) 
            : gamma(g), coef(c), degree(d) 
        {}

        // Kernel with gamma == 1, coef == 0, degree == 1.
        sparse_polynomial_kernel() 
            : gamma(1), coef(0), degree(1) 
        {}

        // Copies all three parameters from k.
        sparse_polynomial_kernel(const sparse_polynomial_kernel& k) 
            : gamma(k.gamma), coef(k.coef), degree(k.degree) 
        {}

        typedef T type;

        // Kernel parameters.  They are declared const, so operator= and the
        // deserialize() overload for this type write them via const_cast.
        const scalar_type gamma;
        const scalar_type coef;
        const scalar_type degree;

        scalar_type operator() (const sample_type& a, const sample_type& b) const
        {
            return std::pow(gamma*(dot(a,b)) + coef, degree);
        }

        sparse_polynomial_kernel& operator= (const sparse_polynomial_kernel& k)
        {
            scalar_type& g = const_cast<scalar_type&>(gamma);
            scalar_type& c = const_cast<scalar_type&>(coef);
            scalar_type& d = const_cast<scalar_type&>(degree);
            g = k.gamma;
            c = k.coef;
            d = k.degree;
            return *this;
        }

        // Equal only when gamma, coef and degree all compare equal.
        bool operator== (const sparse_polynomial_kernel& k) const
        {
            if (gamma != k.gamma)
                return false;
            if (coef != k.coef)
                return false;
            return degree == k.degree;
        }
    };

    template <
        typename T
        >
    void serialize (
        const sparse_polynomial_kernel<T>& item,
        std::ostream& out
    )
    {
        try
        {
            serialize(item.gamma, out);
            serialize(item.coef, out);
            serialize(item.degree, out);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n   while serializing object of type sparse_polynomial_kernel"); 
        }
    }

    template <
        typename T
        >
    void deserialize (
        sparse_polynomial_kernel<T>& item,
        std::istream& in 
    )
    {
        /*!
            Reads gamma, coef and degree (in that order) from in and stores them in item.
            All three values are read into locals first, and item is assigned only after
            every read has succeeded.  If any read throws, item is left unchanged and not
            half-updated.  Serialization errors are rethrown with the kernel type appended.
        !*/
        typedef typename sparse_polynomial_kernel<T>::scalar_type scalar_type;
        scalar_type g = scalar_type();
        scalar_type c = scalar_type();
        scalar_type d = scalar_type();
        try
        {
            deserialize(g, in);
            deserialize(c, in);
            deserialize(d, in);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n   while deserializing object of type sparse_polynomial_kernel"); 
        }
        // Commit: operator= performs the member writes once all fields are known.
        item = sparse_polynomial_kernel<T>(g, c, d);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename T
        >
    struct sparse_sigmoid_kernel
    {
        /*!
            Sigmoid kernel for sparse vectors:
                k(a,b) == std::tanh(gamma*dot(a,b) + coef)
            A default constructed kernel has gamma == 0.1 and coef == -1.0.
        !*/

        typedef typename T::value_type::second_type scalar_type;
        typedef T sample_type;
        typedef default_memory_manager mem_manager_type;

        // Kernel with explicit gamma and coef.
        sparse_sigmoid_kernel(const scalar_type g, const scalar_type c) 
            : gamma(g), coef(c) 
        {}

        // Kernel with gamma == 0.1, coef == -1.0.
        sparse_sigmoid_kernel() 
            : gamma(0.1), coef(-1.0) 
        {}

        // Copies both parameters from k.
        sparse_sigmoid_kernel(const sparse_sigmoid_kernel& k) 
            : gamma(k.gamma), coef(k.coef) 
        {}

        typedef T type;

        // Kernel parameters.  They are declared const, so operator= and the
        // deserialize() overload for this type write them via const_cast.
        const scalar_type gamma;
        const scalar_type coef;

        scalar_type operator() (const sample_type& a, const sample_type& b) const
        {
            return std::tanh(gamma*(dot(a,b)) + coef);
        }

        sparse_sigmoid_kernel& operator= (const sparse_sigmoid_kernel& k)
        {
            scalar_type& g = const_cast<scalar_type&>(gamma);
            scalar_type& c = const_cast<scalar_type&>(coef);
            g = k.gamma;
            c = k.coef;
            return *this;
        }

        // Equal only when both gamma and coef compare equal.
        bool operator== (const sparse_sigmoid_kernel& k) const
        {
            if (gamma != k.gamma)
                return false;
            return coef == k.coef;
        }
    };

    template <
        typename T
        >
    void serialize (
        const sparse_sigmoid_kernel<T>& item,
        std::ostream& out
    )
    {
        try
        {
            serialize(item.gamma, out);
            serialize(item.coef, out);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n   while serializing object of type sparse_sigmoid_kernel"); 
        }
    }

    template <
        typename T
        >
    void deserialize (
        sparse_sigmoid_kernel<T>& item,
        std::istream& in 
    )
    {
        /*!
            Reads gamma then coef from in and stores them in item.
            Both values are read into locals first, and item is assigned only after both
            reads have succeeded.  If either read throws, item is left unchanged and not
            half-updated.  Serialization errors are rethrown with the kernel type appended.
        !*/
        typedef typename sparse_sigmoid_kernel<T>::scalar_type scalar_type;
        scalar_type g = scalar_type();
        scalar_type c = scalar_type();
        try
        {
            deserialize(g, in);
            deserialize(c, in);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n   while deserializing object of type sparse_sigmoid_kernel"); 
        }
        // Commit: operator= performs the member writes once all fields are known.
        item = sparse_sigmoid_kernel<T>(g, c);
    }

// ----------------------------------------------------------------------------------------

    template <typename T>
    struct sparse_linear_kernel
    {
        /*!
            Linear kernel for sparse vectors:
                k(a,b) == dot(a,b)
            The kernel has no parameters, so any two instances compare equal.
        !*/

        typedef typename T::value_type::second_type scalar_type;
        typedef T sample_type;
        typedef default_memory_manager mem_manager_type;

        // Stateless kernel: there is nothing to compare.
        bool operator== (
            const sparse_linear_kernel& 
        ) const
        {
            return true;
        }

        scalar_type operator() (const sample_type& a, const sample_type& b) const
        {
            return dot(a,b);
        }
    };

    template <
        typename T
        >
    void serialize (
        const sparse_linear_kernel<T>& ,
        std::ostream& 
    )
    {
        // sparse_linear_kernel has no parameters, so nothing is written.
    }

    template <
        typename T
        >
    void deserialize (
        sparse_linear_kernel<T>& ,
        std::istream&  
    )
    {
        // sparse_linear_kernel has no parameters, so nothing is read.
    }

// ----------------------------------------------------------------------------------------

    template <typename T>
    struct sparse_histogram_intersection_kernel
    {
        /*!
            Histogram intersection kernel for sparse vectors:
                k(a,b) == sum, over every index present in both a and b, of
                          std::min(a's value, b's value)
            Indices present in only one of the two samples contribute nothing.
            The kernel has no parameters, so any two instances compare equal.
            NOTE(review): the merge below assumes both samples iterate in strictly
            increasing index order (as a std::map does); confirm for other sample types.
        !*/

        typedef typename T::value_type::second_type scalar_type;
        typedef T sample_type;
        typedef default_memory_manager mem_manager_type;

        scalar_type operator() (
            const sample_type& a,
            const sample_type& b
        ) const
        { 
            typename sample_type::const_iterator ai = a.begin();
            typename sample_type::const_iterator bi = b.begin();
            // Neither sample changes during the loop, so fetch the end iterators once.
            const typename sample_type::const_iterator a_end = a.end();
            const typename sample_type::const_iterator b_end = b.end();

            scalar_type sum = 0;
            // Two-pointer merge: advance the side with the smaller index, and
            // accumulate the smaller value only when both indices match.
            while (ai != a_end && bi != b_end)
            {
                if (ai->first == bi->first)
                {
                    sum += std::min(ai->second , bi->second);
                    ++ai;
                    ++bi;
                }
                else if (ai->first < bi->first)
                {
                    ++ai;
                }
                else 
                {
                    ++bi;
                }
            }

            return sum;
        }

        bool operator== (
            const sparse_histogram_intersection_kernel& 
        ) const
        {
            return true;
        }
    };

    template <
        typename T
        >
    void serialize (
        const sparse_histogram_intersection_kernel<T>& ,
        std::ostream& 
    )
    {
        // sparse_histogram_intersection_kernel has no parameters, so nothing is written.
    }

    template <
        typename T
        >
    void deserialize (
        sparse_histogram_intersection_kernel<T>& ,
        std::istream&  
    )
    {
        // sparse_histogram_intersection_kernel has no parameters, so nothing is read.
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SVm_SPARSE_KERNEL