// Copyright (C) 2014  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_
#define DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_


#include "track_association_function_abstract.h"
#include <vector>
#include <iostream>
#include "../algs.h"
#include "../serialize.h"
#include "assignment_function.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Feature extractor that pairs a detection (left-hand element) with a track
    // (right-hand element).  It holds only two counts, the total number of
    // features and the number of non-negative weights, and hands the actual
    // feature computation over to the track object.
    template <typename detection_type>
    class feature_extractor_track_association
    {
    public:
        // The track type comes from the detection type, and the feature vector
        // type comes from the track type.
        typedef typename detection_type::track_type track_type;
        typedef typename track_type::feature_vector_type feature_vector_type;

        // Element types on each side of the pairing.
        typedef detection_type lhs_element;
        typedef track_type rhs_element;

        // A default constructed extractor reports zero for both counts.
        feature_extractor_track_association()
            : feature_count(0), nonnegative_count(0)
        {
        }

        // Stores the two counts exactly as supplied by the caller.
        explicit feature_extractor_track_association(unsigned long num_dims_, unsigned long num_nonnegative_)
            : feature_count(num_dims_), nonnegative_count(num_nonnegative_)
        {
        }

        // Returns the feature count held by this object.
        unsigned long num_features() const
        {
            return feature_count;
        }

        // Returns the non-negative weight count held by this object.
        unsigned long num_nonnegative_weights() const
        {
            return nonnegative_count;
        }

        // Asks the track to fill feats with its similarity features for det.
        void get_features(const detection_type& det, const track_type& track, feature_vector_type& feats) const
        {
            track.get_similarity_features(det, feats);
        }

        // Writes the feature count followed by the non-negative weight count.
        friend void serialize(const feature_extractor_track_association& item, std::ostream& out)
        {
            serialize(item.feature_count, out);
            serialize(item.nonnegative_count, out);
        }

        // Reads the two counts back in the same order serialize() writes them.
        friend void deserialize(feature_extractor_track_association& item, std::istream& in)
        {
            deserialize(item.feature_count, in);
            deserialize(item.nonnegative_count, in);
        }

    private:
        unsigned long feature_count;      // value returned by num_features()
        unsigned long nonnegative_count;  // value returned by num_nonnegative_weights()
    };

// ----------------------------------------------------------------------------------------

    // Function object that uses an assignment_function to associate detections
    // with tracks and then updates the track list in place: matched tracks are
    // updated, unmatched detections start new tracks, and tracks that received
    // no detection are propagated.
    template <typename detection_type_>
    class track_association_function
    {
    public:

        typedef detection_type_ detection_type;
        typedef typename detection_type::track_type track_type;
        typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type;

        // Leaves the underlying assignment function default constructed.
        track_association_function()
        {
        }

        // Stores a copy of the given assignment function.
        track_association_function(const association_function_type& assoc_)
            : assoc(assoc_)
        {
        }

        // Read-only access to the assignment function held by this object.
        const association_function_type& get_assignment_function() const
        {
            return assoc;
        }

        // Associates dets with tracks and modifies tracks accordingly.
        //   - Entry d of the assignment result is treated as the index of the
        //     track matched to dets[d], or -1 when dets[d] matched no track.
        //   - A matched track gets update_track(dets[d]).
        //   - An unmatched detection produces a default constructed track that
        //     is updated with the detection and appended to tracks.
        //   - Every track that was present on entry and received no detection
        //     gets propagate_track().
        void operator()(std::vector<track_type>& tracks, const std::vector<detection_type>& dets) const
        {
            const std::vector<long> matched_track = assoc(dets, tracks);

            // Record the track count before any new tracks are appended, so the
            // propagation pass below only visits tracks that existed on entry.
            const unsigned long num_existing = tracks.size();
            std::vector<bool> got_detection(num_existing, false);

            for (unsigned long d = 0; d < matched_track.size(); ++d)
            {
                const long t = matched_track[d];

                if (t == -1)
                {
                    // No track claimed this detection, so it starts a new one.
                    track_type fresh_track;
                    fresh_track.update_track(dets[d]);
                    tracks.push_back(fresh_track);
                    continue;
                }

                tracks[t].update_track(dets[d]);
                got_detection[t] = true;
            }

            // Tracks without a detection in this call are propagated instead.
            for (unsigned long t = 0; t < num_existing; ++t)
            {
                if (got_detection[t])
                    continue;

                tracks[t].propagate_track();
            }
        }

        // Writes a format version number (1) followed by the assignment function.
        friend void serialize(const track_association_function& item, std::ostream& out)
        {
            int version = 1;
            serialize(version, out);
            serialize(item.assoc, out);
        }

        // Reads the version number, throws serialization_error unless it is 1,
        // then reads the assignment function.
        friend void deserialize(track_association_function& item, std::istream& in)
        {
            int version = 0;
            deserialize(version, in);
            if (version != 1)
            {
                throw serialization_error("Unexpected version found while deserializing dlib::track_association_function.");
            }

            deserialize(item.assoc, in);
        }

    private:

        // The assignment function used to pair detections with tracks.
        association_function_type assoc;
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_