// Copyright (C) 2014 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_ #define DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_ #include "track_association_function_abstract.h" #include <vector> #include <iostream> #include "../algs.h" #include "../serialize.h" #include "assignment_function.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename detection_type > class feature_extractor_track_association { public: typedef typename detection_type::track_type track_type; typedef typename track_type::feature_vector_type feature_vector_type; typedef detection_type lhs_element; typedef track_type rhs_element; feature_extractor_track_association() : num_dims(0), num_nonnegative(0) {} explicit feature_extractor_track_association ( unsigned long num_dims_, unsigned long num_nonnegative_ ) : num_dims(num_dims_), num_nonnegative(num_nonnegative_) {} unsigned long num_features( ) const { return num_dims; } unsigned long num_nonnegative_weights ( ) const { return num_nonnegative; } void get_features ( const detection_type& det, const track_type& track, feature_vector_type& feats ) const { track.get_similarity_features(det, feats); } friend void serialize (const feature_extractor_track_association& item, std::ostream& out) { serialize(item.num_dims, out); serialize(item.num_nonnegative, out); } friend void deserialize (feature_extractor_track_association& item, std::istream& in) { deserialize(item.num_dims, in); deserialize(item.num_nonnegative, in); } private: unsigned long num_dims; unsigned long num_nonnegative; }; // ---------------------------------------------------------------------------------------- template < typename detection_type_ > class track_association_function { public: typedef detection_type_ detection_type; typedef typename detection_type::track_type track_type; typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type; track_association_function() {} track_association_function ( const association_function_type& assoc_ ) : assoc(assoc_) { } const association_function_type& get_assignment_function ( ) const { return assoc; } void operator() ( std::vector<track_type>& tracks, const std::vector<detection_type>& dets ) const { std::vector<long> assignments = assoc(dets, tracks); std::vector<bool> updated_track(tracks.size(), false); // now update all the tracks with the detections that associated to them. for (unsigned long i = 0; i < assignments.size(); ++i) { if (assignments[i] != -1) { tracks[assignments[i]].update_track(dets[i]); updated_track[assignments[i]] = true; } else { track_type new_track; new_track.update_track(dets[i]); tracks.push_back(new_track); } } // Now propagate all the tracks that didn't get any detections. for (unsigned long i = 0; i < updated_track.size(); ++i) { if (!updated_track[i]) tracks[i].propagate_track(); } } friend void serialize (const track_association_function& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.assoc, out); } friend void deserialize (track_association_function& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::track_association_function."); deserialize(item.assoc, in); } private: assignment_function<feature_extractor_track_association<detection_type> > assoc; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_