// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SCAN_IMAGE_CuSTOM_Hh_
#define DLIB_SCAN_IMAGE_CuSTOM_Hh_

#include "scan_image_custom_abstract.h"
#include "../matrix.h"
#include "../geometry.h"
#include <vector>
#include "../image_processing/full_object_detection.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // scan_image_custom: an image scanner that delegates all image-specific work to a
    // user supplied feature extractor (Feature_extractor_type).  load() hands the image
    // to the extractor, which fills search_rects with the candidate rectangles; detect()
    // then scores every one of those rectangles against a weight vector w.  Only
    // rectangles returned by the extractor's load() are ever searched.
    template <
        typename Feature_extractor_type
        >
    class scan_image_custom : noncopyable
    {

    public:

        // Dense column vector type used for both weight vectors (w) and feature
        // vectors (psi).
        typedef matrix<double,0,1> feature_vector_type;
        typedef Feature_extractor_type feature_extractor_type;

        // Constructs an object that is not yet loaded with an image.
        scan_image_custom (
        );  

        // Passes img to the feature extractor, which fills search_rects, then marks this
        // object as loaded.
        template <
            typename image_type
            >
        void load (
            const image_type& img
        );

        // Returns the loaded flag: set by load(), or restored by deserialize().
        inline bool is_loaded_with_image (
        ) const;

        // Copies the configuration of fe into this object's feature extractor.
        inline void copy_configuration(
            const feature_extractor_type& fe
        );

        // Read-only access to the wrapped feature extractor.
        const Feature_extractor_type& get_feature_extractor (
        ) const { return feats; }

        // Copies item's feature extractor configuration into *this.  Only the
        // extractor's configuration is copied; search_rects and the loaded flag are not.
        inline void copy_configuration (
            const scan_image_custom& item
        );

        // Dimensionality of feature vectors, as reported by the feature extractor.
        inline long get_num_dimensions (
        ) const;

        // Scores every search rectangle against w and stores those with score >= thresh
        // in dets, sorted by descending score.  Requires is_loaded_with_image() and
        // w.size() >= get_num_dimensions().
        void detect (
            const feature_vector_type& w,
            std::vector<std::pair<double, rectangle> >& dets,
            const double thresh
        ) const;

        // Picks the search rectangle that best matches obj.get_rect() and passes it,
        // together with psi, to the extractor's get_feature_vector().  obj must have no
        // parts and psi.size() must be >= get_num_dimensions().
        void get_feature_vector (
            const full_object_detection& obj,
            feature_vector_type& psi
        ) const;

        // Returns a full_object_detection holding just rect.  w is ignored because this
        // scanner has no movable parts.
        full_object_detection get_full_object_detection (
            const rectangle& rect,
            const feature_vector_type& w
        ) const;

        // Returns the element of search_rects with the highest overlap ratio
        // area(intersection) / area(rect + candidate) with rect, or a
        // default-constructed rectangle when there are no search rectangles.
        const rectangle get_best_matching_rect (
            const rectangle& rect
        ) const;

        // This scanner always has exactly one detection template ...
        inline unsigned long get_num_detection_templates (
        ) const { return 1; }

        // ... and that template has no movable components.
        inline unsigned long get_num_movable_components_per_detection_template (
        ) const { return 0; }

        template <typename T>
        friend void serialize (
            const scan_image_custom<T>& item,
            std::ostream& out
        );

        template <typename T>
        friend void deserialize (
            scan_image_custom<T>& item,
            std::istream& in 
        );

    private:
        // Orders (score, rect) pairs by ascending score.  detect() sorts through
        // reverse iterators with this comparator to get descending order.
        static bool compare_pair_rect (
            const std::pair<double, rectangle>& a,
            const std::pair<double, rectangle>& b
        )
        {
            return a.first < b.first;
        }


        // Declares the trait has_compute_object_score<T>, used below with
        // enable_if/disable_if to choose a scoring strategy.  It is intended to detect
        // whether T has a member
        //     double compute_object_score(const matrix<double,0,1>& w, const rectangle& obj) const
        // so the signature below must match the extractor's exactly.
        DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(
            has_compute_object_score,
            double, 
            compute_object_score,
            ( const matrix<double,0,1>& w, const rectangle& obj) const
        );

        // Used when fe_type provides compute_object_score(): the extractor scores each
        // search rectangle directly.  Rectangles with score >= thresh are appended to
        // dets; dets is never cleared here.
        template <typename fe_type>
        typename enable_if<has_compute_object_score<fe_type> >::type compute_all_rect_scores (
            const fe_type& feats,
            const feature_vector_type& w,
            std::vector<std::pair<double, rectangle> >& dets,
            const double thresh
        ) const
        {
            for (unsigned long i = 0; i < search_rects.size(); ++i)
            {
                const double score = feats.compute_object_score(w, search_rects[i]);
                if (score >= thresh)
                {
                    dets.push_back(std::make_pair(score, search_rects[i]));
                }
            }
        }

        // Fallback used when fe_type has no compute_object_score().  psi is not cleared
        // between iterations, so the score of search_rects[i] is computed as the change
        // in dot(psi, w) after calling feats.get_feature_vector() for that rectangle.
        // NOTE(review): this is only correct if the extractor's get_feature_vector()
        // adds into psi rather than overwriting it, which the loop comment below
        // implies; confirm against the extractor contract.  psi is sized to w.size(),
        // which detect() asserts is >= get_num_dimensions().  Results are appended to
        // dets; dets is never cleared here.
        template <typename fe_type>
        typename disable_if<has_compute_object_score<fe_type> >::type compute_all_rect_scores (
            const fe_type& feats,
            const feature_vector_type& w,
            std::vector<std::pair<double, rectangle> >& dets,
            const double thresh
        ) const
        {
            matrix<double,0,1> psi(w.size());
            psi = 0;
            double prev_dot = 0;
            for (unsigned long i = 0; i < search_rects.size(); ++i)
            {
                // Reset these back to zero every so often to avoid the accumulation of
                // rounding error.  Note that the only reason we do this loop in this
                // complex way is to avoid needing to zero the psi vector every iteration.
                if ((i%500) == 499)
                {
                    psi = 0;
                    prev_dot = 0;
                }

                feats.get_feature_vector(search_rects[i], psi);
                const double cur_dot = dot(psi, w);
                // Difference between consecutive running dot products = this rect's score.
                const double score = cur_dot - prev_dot;
                if (score >= thresh)
                {
                    dets.push_back(std::make_pair(score, search_rects[i]));
                }
                prev_dot = cur_dot;
            }
        }


        // The user-supplied feature extractor that does all image-specific work.
        feature_extractor_type feats;
        // Candidate rectangles produced by feats.load() in load(); these are the only
        // locations detect() scores.
        std::vector<rectangle> search_rects;
        // True once load() has run (or restored by deserialize()).
        bool loaded_with_image;
    };

// ----------------------------------------------------------------------------------------

    // Writes item to out.  Stream layout (format version 1), in order:
    //   1. the int version tag,
    //   2. the feature extractor,
    //   3. the candidate search rectangles,
    //   4. the loaded-with-image flag,
    //   5. the extractor's dimensionality, which deserialize() re-checks so that data
    //      written by an incompatible extractor is rejected.
    template <typename T>
    void serialize (
        const scan_image_custom<T>& item,
        std::ostream& out
    )
    {
        const int format_version = 1;
        serialize(format_version, out);

        serialize(item.feats, out);
        serialize(item.search_rects, out);
        serialize(item.loaded_with_image, out);

        const long num_dims = item.get_num_dimensions();
        serialize(num_dims, out);
    }

// ----------------------------------------------------------------------------------------

    // Reads back an object written by serialize(): version tag, feature extractor,
    // search rectangles, loaded flag, and finally the extractor's dimensionality.
    //
    // Throws serialization_error if the version tag is not 1 or if the stored
    // dimensionality differs from item.get_num_dimensions() after the extractor has
    // been read.  The nested deserialize() calls may also throw.  Fields are read
    // directly into item, so item may be left partially updated if an exception
    // escapes.
    template <typename T>
    void deserialize (
        scan_image_custom<T>& item,
        std::istream& in 
    )
    {
        int version = 0;
        deserialize(version, in);
        if (version != 1)
            throw serialization_error("Unsupported version found when deserializing a scan_image_custom object.");

        deserialize(item.feats, in);
        deserialize(item.search_rects, in);
        deserialize(item.loaded_with_image, in);

        // When developing some feature extractor, it's easy to accidentally change its
        // number of dimensions and then try to deserialize data from an older version of
        // your extractor into the current code.  This check is here to catch that kind of
        // user error.
        // Initialized so the value is well defined on every path, in keeping with
        // `version` above.
        long dims = 0;
        deserialize(dims, in);
        if (item.get_num_dimensions() != dims)
            throw serialization_error("Number of dimensions in serialized scan_image_custom doesn't match the expected number.");
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                         scan_image_custom member functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Default constructor: the feature extractor and the (empty) list of search
    // rectangles are default-constructed, and the object starts out not loaded with any
    // image.
    template <
        typename Feature_extractor_type
        >
    scan_image_custom<Feature_extractor_type>::
    scan_image_custom () : loaded_with_image(false) {}

// ----------------------------------------------------------------------------------------

    // Hands img to the feature extractor.  The extractor does all image-specific
    // processing and reports which rectangles should be searched by writing them into
    // search_rects.  The object counts as loaded only once that call has returned.
    template <
        typename Feature_extractor_type
        >
    template <
        typename image_type
        >
    void scan_image_custom<Feature_extractor_type>::
    load (
        const image_type& img
    )
    {
        this->feats.load(img, this->search_rects);
        this->loaded_with_image = true;
    }

// ----------------------------------------------------------------------------------------

    // Reports the loaded flag: it is set by load() and can also be restored by
    // deserialize().
    template <
        typename Feature_extractor_type
        >
    bool scan_image_custom<Feature_extractor_type>::
    is_loaded_with_image (
    ) const
    {
        const bool loaded = this->loaded_with_image;
        return loaded;
    }

// ----------------------------------------------------------------------------------------

    // Forwards to the wrapped extractor so that it adopts fe's configuration.  Neither
    // search_rects nor the loaded flag is touched.
    template <
        typename Feature_extractor_type
        >
    void scan_image_custom<Feature_extractor_type>::
    copy_configuration(
        const feature_extractor_type& fe
    )
    {
        feature_extractor_type& own_extractor = this->feats;
        own_extractor.copy_configuration(fe);
    }

// ----------------------------------------------------------------------------------------

    // Copies only the feature extractor configuration of item into *this.  The search
    // rectangles and loaded flag of *this are left unchanged.
    template <
        typename Feature_extractor_type
        >
    void scan_image_custom<Feature_extractor_type>::
    copy_configuration (
        const scan_image_custom& item
    )
    {
        const feature_extractor_type& source_extractor = item.feats;
        this->feats.copy_configuration(source_extractor);
    }

// ----------------------------------------------------------------------------------------

    // The feature space dimensionality is defined entirely by the feature extractor.
    template <
        typename Feature_extractor_type
        >
    long scan_image_custom<Feature_extractor_type>::
    get_num_dimensions (
    ) const
    {
        const long dims = this->feats.get_num_dimensions();
        return dims;
    }

// ----------------------------------------------------------------------------------------

    // Scores every search rectangle against w.  On return, dets holds every
    // (score, rect) pair with score >= thresh, best score first.  Requires an image to
    // be loaded and w.size() >= get_num_dimensions().
    template <
        typename Feature_extractor_type
        >
    void scan_image_custom<Feature_extractor_type>::
    detect (
        const feature_vector_type& w,
        std::vector<std::pair<double, rectangle> >& dets,
        const double thresh
    ) const
    {
        // Precondition check (active only in debug builds via DLIB_ASSERT).
        DLIB_ASSERT(is_loaded_with_image() &&
                    w.size() >= get_num_dimensions(), 
            "\t void scan_image_custom::detect()"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
            << "\n\t w.size():               " << w.size()
            << "\n\t get_num_dimensions():   " << get_num_dimensions()
            << "\n\t this: " << this
            );

        // compute_all_rect_scores() only appends, so start from an empty output.
        dets.clear();
        this->compute_all_rect_scores(this->feats, w, dets, thresh);

        // compare_pair_rect orders by ascending score.  Sorting through reverse
        // iterators therefore leaves dets in descending score order.
        std::sort(dets.rbegin(), dets.rend(), &scan_image_custom::compare_pair_rect);
    }

// ----------------------------------------------------------------------------------------

    // Returns the element of search_rects with the largest overlap ratio
    //     area(rect.intersect(candidate)) / area(rect + candidate)
    // with rect.  When several candidates tie, the earliest one wins.  If search_rects
    // is empty, a default-constructed rectangle is returned.  Requires
    // is_loaded_with_image().
    template <
        typename Feature_extractor_type
        >
    const rectangle scan_image_custom<Feature_extractor_type>::
    get_best_matching_rect (
        const rectangle& rect
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_loaded_with_image(),
            "\t const rectangle scan_image_custom::get_best_matching_rect()"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
            << "\n\t this: " << this
            );


        double best_score = -1;
        rectangle best_rect;
        for (unsigned long i = 0; i < search_rects.size(); ++i)
        {
            const rectangle& candidate = search_rects[i];
            const double denom = static_cast<double>((rect + candidate).area());
            // The denominator can be zero (e.g. when both rectangles are empty).  Skip
            // such candidates explicitly instead of computing 0/0: that yields NaN, which
            // is only ignored because NaN comparisons are false.  -ffast-math breaks that
            // assumption, and enabled FP-exception trapping (FE_INVALID) would fault.
            if (denom == 0)
                continue;

            const double score = rect.intersect(candidate).area()/denom;
            if (score > best_score)
            {
                best_score = score;
                best_rect = candidate;
            }
        }
        return best_rect;
    }

// ----------------------------------------------------------------------------------------

    // This scanner has no movable parts, so a detection is fully described by its
    // bounding box.  The weight vector therefore plays no role and is left unnamed.
    template <
        typename Feature_extractor_type
        >
    full_object_detection scan_image_custom<Feature_extractor_type>::
    get_full_object_detection (
        const rectangle& rect,
        const feature_vector_type& /*w*/
    ) const
    {
        full_object_detection detection(rect);
        return detection;
    }

// ----------------------------------------------------------------------------------------

    // Looks up the search rectangle that best matches obj.get_rect() (see
    // get_best_matching_rect()) and passes it, together with psi, to the feature
    // extractor's get_feature_vector().  Requires an image to be loaded,
    // psi.size() >= get_num_dimensions(), and an obj with no parts.
    template <
        typename Feature_extractor_type
        >
    void scan_image_custom<Feature_extractor_type>::
    get_feature_vector (
        const full_object_detection& obj,
        feature_vector_type& psi
    ) const
    {
        // Precondition check (active only in debug builds via DLIB_ASSERT).
        DLIB_ASSERT(is_loaded_with_image() &&
                    psi.size() >= get_num_dimensions() &&
                    obj.num_parts() == 0,
            "\t void scan_image_custom::get_feature_vector()"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
            << "\n\t psi.size():             " << psi.size()
            << "\n\t get_num_dimensions():   " << get_num_dimensions()
            << "\n\t obj.num_parts():                            " << obj.num_parts()
            << "\n\t this: " << this
            );

        const rectangle matched = get_best_matching_rect(obj.get_rect());
        feats.get_feature_vector(matched, psi);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SCAN_IMAGE_CuSTOM_Hh_