// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_OBJECT_DeTECTOR_Hh_ #define DLIB_OBJECT_DeTECTOR_Hh_ #include "object_detector_abstract.h" #include "../geometry.h" #include <vector> #include "box_overlap_testing.h" #include "full_object_detection.h" namespace dlib { // ---------------------------------------------------------------------------------------- struct rect_detection { double detection_confidence; unsigned long weight_index; rectangle rect; bool operator<(const rect_detection& item) const { return detection_confidence < item.detection_confidence; } }; struct full_detection { double detection_confidence; unsigned long weight_index; full_object_detection rect; bool operator<(const full_detection& item) const { return detection_confidence < item.detection_confidence; } }; // ---------------------------------------------------------------------------------------- template <typename image_scanner_type> struct processed_weight_vector { processed_weight_vector(){} typedef typename image_scanner_type::feature_vector_type feature_vector_type; void init ( const image_scanner_type& ) /*! requires - w has already been assigned its value. Note that the point of this function is to allow an image scanner to overload the processed_weight_vector template and provide some different kind of object as the output of get_detect_argument(). For example, the scan_fhog_pyramid object uses an overload that causes get_detect_argument() to return the special fhog_filterbank object instead of a feature_vector_type. This avoids needing to construct the fhog_filterbank during each call to detect and therefore speeds up detection. !*/ {} // return the first argument to image_scanner_type::detect() const feature_vector_type& get_detect_argument() const { return w; } feature_vector_type w; }; // ---------------------------------------------------------------------------------------- template < typename image_scanner_type_ > class object_detector { public: typedef image_scanner_type_ image_scanner_type; typedef typename image_scanner_type::feature_vector_type feature_vector_type; object_detector ( ); object_detector ( const object_detector& item ); object_detector ( const image_scanner_type& scanner_, const test_box_overlap& overlap_tester_, const feature_vector_type& w_ ); object_detector ( const image_scanner_type& scanner_, const test_box_overlap& overlap_tester_, const std::vector<feature_vector_type>& w_ ); explicit object_detector ( const std::vector<object_detector>& detectors ); unsigned long num_detectors ( ) const { return w.size(); } const feature_vector_type& get_w ( unsigned long idx = 0 ) const { return w[idx].w; } const processed_weight_vector<image_scanner_type>& get_processed_w ( unsigned long idx = 0 ) const { return w[idx]; } const test_box_overlap& get_overlap_tester ( ) const; const image_scanner_type& get_scanner ( ) const; object_detector& operator= ( const object_detector& item ); template < typename image_type > std::vector<rectangle> operator() ( const image_type& img, double adjust_threshold = 0 ); template < typename image_type > void operator() ( const image_type& img, std::vector<std::pair<double, rectangle> >& final_dets, double adjust_threshold = 0 ); template < typename image_type > void operator() ( const image_type& img, std::vector<std::pair<double, full_object_detection> >& final_dets, double adjust_threshold = 0 ); template < typename image_type > void operator() ( const image_type& img, std::vector<full_object_detection>& final_dets, double adjust_threshold = 0 ); // These typedefs are here for backwards compatibility with previous versions of // dlib. typedef ::dlib::rect_detection rect_detection; typedef ::dlib::full_detection full_detection; template < typename image_type > void operator() ( const image_type& img, std::vector<rect_detection>& final_dets, double adjust_threshold = 0 ); template < typename image_type > void operator() ( const image_type& img, std::vector<full_detection>& final_dets, double adjust_threshold = 0 ); template <typename T> friend void serialize ( const object_detector<T>& item, std::ostream& out ); template <typename T> friend void deserialize ( object_detector<T>& item, std::istream& in ); private: bool overlaps_any_box ( const std::vector<rect_detection>& rects, const dlib::rectangle& rect ) const { for (unsigned long i = 0; i < rects.size(); ++i) { if (boxes_overlap(rects[i].rect, rect)) return true; } return false; } test_box_overlap boxes_overlap; std::vector<processed_weight_vector<image_scanner_type> > w; image_scanner_type scanner; }; // ---------------------------------------------------------------------------------------- template <typename T> void serialize ( const object_detector<T>& item, std::ostream& out ) { int version = 2; serialize(version, out); T scanner; scanner.copy_configuration(item.scanner); serialize(scanner, out); serialize(item.boxes_overlap, out); // serialize all the weight vectors serialize(item.w.size(), out); for (unsigned long i = 0; i < item.w.size(); ++i) serialize(item.w[i].w, out); } // ---------------------------------------------------------------------------------------- template <typename T> void deserialize ( object_detector<T>& item, std::istream& in ) { int version = 0; deserialize(version, in); if (version == 1) { deserialize(item.scanner, in); item.w.resize(1); deserialize(item.w[0].w, in); item.w[0].init(item.scanner); deserialize(item.boxes_overlap, in); } else if (version == 2) { deserialize(item.scanner, in); deserialize(item.boxes_overlap, in); unsigned long num_detectors = 0; deserialize(num_detectors, in); item.w.resize(num_detectors); for (unsigned long i = 0; i < item.w.size(); ++i) { deserialize(item.w[i].w, in); item.w[i].init(item.scanner); } } else { throw serialization_error("Unexpected version encountered while deserializing a dlib::object_detector object."); } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // object_detector member functions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > object_detector<image_scanner_type>:: object_detector ( ) { } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > object_detector<image_scanner_type>:: object_detector ( const object_detector& item ) { boxes_overlap = item.boxes_overlap; w = item.w; scanner.copy_configuration(item.scanner); } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > object_detector<image_scanner_type>:: object_detector ( const image_scanner_type& scanner_, const test_box_overlap& overlap_tester, const feature_vector_type& w_ ) : boxes_overlap(overlap_tester) { // make sure requires clause is not broken DLIB_ASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() == scanner_.get_num_dimensions() + 1, "\t object_detector::object_detector(scanner_,overlap_tester,w_)" << "\n\t Invalid inputs were given to this function " << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() << "\n\t w_.size(): " << w_.size() << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions() << "\n\t this: " << this ); scanner.copy_configuration(scanner_); w.resize(1); w[0].w = w_; w[0].init(scanner); } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > object_detector<image_scanner_type>:: object_detector ( const image_scanner_type& scanner_, const test_box_overlap& overlap_tester, const std::vector<feature_vector_type>& w_ ) : boxes_overlap(overlap_tester) { // make sure requires clause is not broken DLIB_CASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() > 0, "\t object_detector::object_detector(scanner_,overlap_tester,w_)" << "\n\t Invalid inputs were given to this function " << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() << "\n\t w_.size(): " << w_.size() << "\n\t this: " << this ); for (unsigned long i = 0; i < w_.size(); ++i) { DLIB_CASSERT(w_[i].size() == scanner_.get_num_dimensions() + 1, "\t object_detector::object_detector(scanner_,overlap_tester,w_)" << "\n\t Invalid inputs were given to this function " << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() << "\n\t w_["<<i<<"].size(): " << w_[i].size() << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions() << "\n\t this: " << this ); } scanner.copy_configuration(scanner_); w.resize(w_.size()); for (unsigned long i = 0; i < w.size(); ++i) { w[i].w = w_[i]; w[i].init(scanner); } } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > object_detector<image_scanner_type>:: object_detector ( const std::vector<object_detector>& detectors ) { DLIB_CASSERT(detectors.size() != 0, "\t object_detector::object_detector(detectors)" << "\n\t Invalid inputs were given to this function " << "\n\t this: " << this ); std::vector<feature_vector_type> weights; weights.reserve(detectors.size()); for (unsigned long i = 0; i < detectors.size(); ++i) { for (unsigned long j = 0; j < detectors[i].num_detectors(); ++j) weights.push_back(detectors[i].get_w(j)); } *this = object_detector(detectors[0].get_scanner(), detectors[0].get_overlap_tester(), weights); } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > object_detector<image_scanner_type>& object_detector<image_scanner_type>:: operator= ( const object_detector& item ) { if (this == &item) return *this; boxes_overlap = item.boxes_overlap; w = item.w; scanner.copy_configuration(item.scanner); return *this; } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > template < typename image_type > void object_detector<image_scanner_type>:: operator() ( const image_type& img, std::vector<rect_detection>& final_dets, double adjust_threshold ) { scanner.load(img); std::vector<std::pair<double, rectangle> > dets; std::vector<rect_detection> dets_accum; for (unsigned long i = 0; i < w.size(); ++i) { const double thresh = w[i].w(scanner.get_num_dimensions()); scanner.detect(w[i].get_detect_argument(), dets, thresh + adjust_threshold); for (unsigned long j = 0; j < dets.size(); ++j) { rect_detection temp; temp.detection_confidence = dets[j].first-thresh; temp.weight_index = i; temp.rect = dets[j].second; dets_accum.push_back(temp); } } // Do non-max suppression final_dets.clear(); if (w.size() > 1) std::sort(dets_accum.rbegin(), dets_accum.rend()); for (unsigned long i = 0; i < dets_accum.size(); ++i) { if (overlaps_any_box(final_dets, dets_accum[i].rect)) continue; final_dets.push_back(dets_accum[i]); } } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > template < typename image_type > void object_detector<image_scanner_type>:: operator() ( const image_type& img, std::vector<full_detection>& final_dets, double adjust_threshold ) { std::vector<rect_detection> dets; (*this)(img,dets,adjust_threshold); final_dets.resize(dets.size()); // convert all the rectangle detections into full_object_detections. for (unsigned long i = 0; i < dets.size(); ++i) { final_dets[i].detection_confidence = dets[i].detection_confidence; final_dets[i].weight_index = dets[i].weight_index; final_dets[i].rect = scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w); } } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > template < typename image_type > std::vector<rectangle> object_detector<image_scanner_type>:: operator() ( const image_type& img, double adjust_threshold ) { std::vector<rect_detection> dets; (*this)(img,dets,adjust_threshold); std::vector<rectangle> final_dets(dets.size()); for (unsigned long i = 0; i < dets.size(); ++i) final_dets[i] = dets[i].rect; return final_dets; } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > template < typename image_type > void object_detector<image_scanner_type>:: operator() ( const image_type& img, std::vector<std::pair<double, rectangle> >& final_dets, double adjust_threshold ) { std::vector<rect_detection> dets; (*this)(img,dets,adjust_threshold); final_dets.resize(dets.size()); for (unsigned long i = 0; i < dets.size(); ++i) final_dets[i] = std::make_pair(dets[i].detection_confidence,dets[i].rect); } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > template < typename image_type > void object_detector<image_scanner_type>:: operator() ( const image_type& img, std::vector<std::pair<double, full_object_detection> >& final_dets, double adjust_threshold ) { std::vector<rect_detection> dets; (*this)(img,dets,adjust_threshold); final_dets.clear(); final_dets.reserve(dets.size()); // convert all the rectangle detections into full_object_detections. for (unsigned long i = 0; i < dets.size(); ++i) { final_dets.push_back(std::make_pair(dets[i].detection_confidence, scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w))); } } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > template < typename image_type > void object_detector<image_scanner_type>:: operator() ( const image_type& img, std::vector<full_object_detection>& final_dets, double adjust_threshold ) { std::vector<rect_detection> dets; (*this)(img,dets,adjust_threshold); final_dets.clear(); final_dets.reserve(dets.size()); // convert all the rectangle detections into full_object_detections. for (unsigned long i = 0; i < dets.size(); ++i) { final_dets.push_back(scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w)); } } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > const test_box_overlap& object_detector<image_scanner_type>:: get_overlap_tester ( ) const { return boxes_overlap; } // ---------------------------------------------------------------------------------------- template < typename image_scanner_type > const image_scanner_type& object_detector<image_scanner_type>:: get_scanner ( ) const { return scanner; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_OBJECT_DeTECTOR_Hh_