// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_SEQUENCE_LAbELER_H_h_ #define DLIB_SEQUENCE_LAbELER_H_h_ #include "sequence_labeler_abstract.h" #include "../matrix.h" #include <vector> #include "../optimization/find_max_factor_graph_viterbi.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace fe_helpers { template <typename EXP> struct dot_functor { dot_functor(const matrix_exp<EXP>& lambda_) : lambda(lambda_), value(0) {} inline void operator() ( unsigned long feat_index ) { value += lambda(feat_index); } inline void operator() ( unsigned long feat_index, double feat_value ) { value += feat_value*lambda(feat_index); } const matrix_exp<EXP>& lambda; double value; }; template <typename feature_extractor, typename EXP, typename sequence_type, typename EXP2> double dot( const matrix_exp<EXP>& lambda, const feature_extractor& fe, const sequence_type& sequence, const matrix_exp<EXP2>& candidate_labeling, unsigned long position ) { dot_functor<EXP> dot(lambda); fe.get_features(dot, sequence, candidate_labeling, position); return dot.value; } } // ---------------------------------------------------------------------------------------- namespace impl { DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST( has_reject_labeling, bool, template reject_labeling<matrix<unsigned long> >, (const typename T::sequence_type&, const matrix_exp<matrix<unsigned long> >&, unsigned long)const ); template <typename feature_extractor, typename EXP, typename sequence_type> typename enable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists ( const feature_extractor& fe, const sequence_type& x, const matrix_exp<EXP>& y, unsigned long position ) { return fe.reject_labeling(x, y, position); } template <typename feature_extractor, typename EXP, typename sequence_type> typename disable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists ( const feature_extractor& , const sequence_type& , const matrix_exp<EXP>& , unsigned long ) { return false; } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename feature_extractor > typename enable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling ( const feature_extractor& fe, const typename feature_extractor::sequence_type& x, const std::vector<unsigned long>& y ) { if (x.size() != y.size()) return true; matrix<unsigned long,0,1> node_states; for (unsigned long i = 0; i < x.size(); ++i) { node_states.set_size(std::min(fe.order(),i) + 1); for (unsigned long j = 0; j < (unsigned long)node_states.size(); ++j) node_states(j) = y[i-j]; if (fe.reject_labeling(x, node_states, i)) return true; } return false; } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > typename disable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling ( const feature_extractor& , const typename feature_extractor::sequence_type& x, const std::vector<unsigned long>& y ) { if (x.size() != y.size()) return true; return false; } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > bool contains_invalid_labeling ( const feature_extractor& fe, const std::vector<typename feature_extractor::sequence_type>& x, const std::vector<std::vector<unsigned long> >& y ) { if (x.size() != y.size()) return true; for (unsigned long i = 0; i < x.size(); ++i) { if (contains_invalid_labeling(fe,x[i],y[i])) return true; } return false; } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class sequence_labeler { public: typedef typename feature_extractor::sequence_type sample_sequence_type; typedef std::vector<unsigned long> labeled_sequence_type; private: class map_prob { public: unsigned long order() const { return fe.order(); } unsigned long num_states() const { return fe.num_labels(); } map_prob( const sample_sequence_type& x_, const feature_extractor& fe_, const matrix<double,0,1>& weights_ ) : sequence(x_), fe(fe_), weights(weights_) { } unsigned long number_of_nodes( ) const { return sequence.size(); } template < typename EXP > double factor_value ( unsigned long node_id, const matrix_exp<EXP>& node_states ) const { if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id)) return -std::numeric_limits<double>::infinity(); return fe_helpers::dot(weights, fe, sequence, node_states, node_id); } const sample_sequence_type& sequence; const feature_extractor& fe; const matrix<double,0,1>& weights; }; public: sequence_labeler() { weights.set_size(fe.num_features()); weights = 0; } explicit sequence_labeler( const matrix<double,0,1>& weights_ ) : weights(weights_) { // make sure requires clause is not broken DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()), "\t sequence_labeler::sequence_labeler(weights_)" << "\n\t These sizes should match" << "\n\t fe.num_features(): " << fe.num_features() << "\n\t weights_.size(): " << weights_.size() << "\n\t this: " << this ); } sequence_labeler( const matrix<double,0,1>& weights_, const feature_extractor& fe_ ) : fe(fe_), weights(weights_) { // make sure requires clause is not broken DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()), "\t sequence_labeler::sequence_labeler(weights_,fe_)" << "\n\t These sizes should match" << "\n\t fe_.num_features(): " << fe_.num_features() << "\n\t weights_.size(): " << weights_.size() << "\n\t this: " << this ); } const feature_extractor& get_feature_extractor ( ) const { return fe; } const matrix<double,0,1>& get_weights ( ) const { return weights; } unsigned long num_labels ( ) const { return fe.num_labels(); } labeled_sequence_type operator() ( const sample_sequence_type& x ) const { // make sure requires clause is not broken DLIB_ASSERT(num_labels() > 0, "\t labeled_sequence_type sequence_labeler::operator()(x)" << "\n\t You can't have no labels." << "\n\t this: " << this ); labeled_sequence_type y; find_max_factor_graph_viterbi(map_prob(x,fe,weights), y); return y; } void label_sequence ( const sample_sequence_type& x, labeled_sequence_type& y ) const { // make sure requires clause is not broken DLIB_ASSERT(num_labels() > 0, "\t void sequence_labeler::label_sequence(x,y)" << "\n\t You can't have no labels." << "\n\t this: " << this ); find_max_factor_graph_viterbi(map_prob(x,fe,weights), y); } private: feature_extractor fe; matrix<double,0,1> weights; }; // ---------------------------------------------------------------------------------------- template < typename feature_extractor > void serialize ( const sequence_labeler<feature_extractor>& item, std::ostream& out ) { serialize(item.get_feature_extractor(), out); serialize(item.get_weights(), out); } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > void deserialize ( sequence_labeler<feature_extractor>& item, std::istream& in ) { feature_extractor fe; matrix<double,0,1> weights; deserialize(fe, in); deserialize(weights, in); item = sequence_labeler<feature_extractor>(weights, fe); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_SEQUENCE_LAbELER_H_h_