// Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ #define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ #include <vector> #include "../matrix.h" #include "cross_validate_multiclass_trainer_abstract.h" #include <sstream> namespace dlib { // ---------------------------------------------------------------------------------------- template < typename dec_funct_type, typename sample_type, typename label_type > const matrix<double> test_multiclass_decision_function ( const dec_funct_type& dec_funct, const std::vector<sample_type>& x_test, const std::vector<label_type>& y_test ) { // make sure requires clause is not broken DLIB_ASSERT( is_learning_problem(x_test,y_test) == true, "\tmatrix test_multiclass_decision_function()" << "\n\t invalid inputs were given to this function" << "\n\t is_learning_problem(x_test,y_test): " << is_learning_problem(x_test,y_test)); const std::vector<label_type> all_labels = dec_funct.get_labels(); // make a lookup table that maps from labels to their index in all_labels std::map<label_type,unsigned long> label_to_int; for (unsigned long i = 0; i < all_labels.size(); ++i) label_to_int[all_labels[i]] = i; matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res; res.set_size(all_labels.size(), all_labels.size()); res = 0; typename std::map<label_type,unsigned long>::const_iterator iter; // now test this trained object for (unsigned long i = 0; i < x_test.size(); ++i) { iter = label_to_int.find(y_test[i]); // ignore samples with labels that the decision function doesn't know about. if (iter == label_to_int.end()) continue; const unsigned long truth = iter->second; const unsigned long pred = label_to_int[dec_funct(x_test[i])]; res(truth,pred) += 1; } return res; } // ---------------------------------------------------------------------------------------- class cross_validation_error : public dlib::error { public: cross_validation_error(const std::string& msg) : dlib::error(msg){}; }; template < typename trainer_type, typename sample_type, typename label_type > const matrix<double> cross_validate_multiclass_trainer ( const trainer_type& trainer, const std::vector<sample_type>& x, const std::vector<label_type>& y, const long folds ) { typedef typename trainer_type::mem_manager_type mem_manager_type; // make sure requires clause is not broken DLIB_ASSERT(is_learning_problem(x,y) == true && 1 < folds && folds <= static_cast<long>(x.size()), "\tmatrix cross_validate_multiclass_trainer()" << "\n\t invalid inputs were given to this function" << "\n\t x.size(): " << x.size() << "\n\t folds: " << folds << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) ); const std::vector<label_type> all_labels = select_all_distinct_labels(y); // count the number of times each label shows up std::map<label_type,long> label_counts; for (unsigned long i = 0; i < y.size(); ++i) label_counts[y[i]] += 1; // figure out how many samples from each class will be in the test and train splits std::map<label_type,long> num_in_test, num_in_train; for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i) { const long in_test = i->second/folds; if (in_test == 0) { std::ostringstream sout; sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl; sout << "than the number of elements of one of the training classes." << std::endl; sout << " folds: "<< folds << std::endl; sout << " size of class " << i->first << ": "<< i->second << std::endl; throw cross_validation_error(sout.str()); } num_in_test[i->first] = in_test; num_in_train[i->first] = i->second - in_test; } std::vector<sample_type> x_test, x_train; std::vector<label_type> y_test, y_train; matrix<double, 0, 0, mem_manager_type> res; std::map<label_type,long> next_test_idx; for (unsigned long i = 0; i < all_labels.size(); ++i) next_test_idx[all_labels[i]] = 0; label_type label; for (long i = 0; i < folds; ++i) { x_test.clear(); y_test.clear(); x_train.clear(); y_train.clear(); // load up the test samples for (unsigned long j = 0; j < all_labels.size(); ++j) { label = all_labels[j]; long next = next_test_idx[label]; long cur = 0; const long num_needed = num_in_test[label]; while (cur < num_needed) { if (y[next] == label) { x_test.push_back(x[next]); y_test.push_back(label); ++cur; } next = (next + 1)%x.size(); } next_test_idx[label] = next; } // load up the training samples for (unsigned long j = 0; j < all_labels.size(); ++j) { label = all_labels[j]; long next = next_test_idx[label]; long cur = 0; const long num_needed = num_in_train[label]; while (cur < num_needed) { if (y[next] == label) { x_train.push_back(x[next]); y_train.push_back(label); ++cur; } next = (next + 1)%x.size(); } } try { // do the training and testing res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test); } catch (invalid_nu_error&) { // just ignore cases which result in an invalid nu } } // for (long i = 0; i < folds; ++i) return res; } } // ---------------------------------------------------------------------------------------- #endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_