// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_RLs_Hh_ #define DLIB_RLs_Hh_ #include "rls_abstract.h" #include "../matrix.h" #include "function.h" namespace dlib { // ---------------------------------------------------------------------------------------- class rls { public: explicit rls( double forget_factor_, double C_ = 1000, bool apply_forget_factor_to_C_ = false ) { // make sure requires clause is not broken DLIB_ASSERT(0 < forget_factor_ && forget_factor_ <= 1 && 0 < C_, "\t rls::rls()" << "\n\t invalid arguments were given to this function" << "\n\t forget_factor_: " << forget_factor_ << "\n\t C_: " << C_ << "\n\t this: " << this ); C = C_; forget_factor = forget_factor_; apply_forget_factor_to_C = apply_forget_factor_to_C_; } rls( ) { C = 1000; forget_factor = 1; apply_forget_factor_to_C = false; } double get_c( ) const { return C; } double get_forget_factor( ) const { return forget_factor; } bool should_apply_forget_factor_to_C ( ) const { return apply_forget_factor_to_C; } template <typename EXP> void train ( const matrix_exp<EXP>& x, double y ) { // make sure requires clause is not broken DLIB_ASSERT(is_col_vector(x) && (get_w().size() == 0 || get_w().size() == x.size()), "\t void rls::train()" << "\n\t invalid arguments were given to this function" << "\n\t is_col_vector(x): " << is_col_vector(x) << "\n\t x.size(): " << x.size() << "\n\t get_w().size(): " << get_w().size() << "\n\t this: " << this ); if (R.size() == 0) { R = identity_matrix<double>(x.size())*C; w.set_size(x.size()); w = 0; } // multiply by forget factor and incorporate x*trans(x) into R. const double l = 1.0/forget_factor; const double temp = 1 + l*trans(x)*R*x; tmp = R*x; R = l*R - l*l*(tmp*trans(tmp))/temp; // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the // identity matrix back in to keep the regularization alive. if (forget_factor != 1 && !apply_forget_factor_to_C) add_eye_to_inv(R, (1-forget_factor)/C); // R should always be symmetric. This line improves numeric stability of this algorithm. if (cnt%10 == 0) R = 0.5*(R + trans(R)); ++cnt; w = w + R*x*(y - trans(x)*w); } const matrix<double,0,1>& get_w( ) const { return w; } template <typename EXP> double operator() ( const matrix_exp<EXP>& x ) const { // make sure requires clause is not broken DLIB_ASSERT(is_col_vector(x) && get_w().size() == x.size(), "\t double rls::operator()()" << "\n\t invalid arguments were given to this function" << "\n\t is_col_vector(x): " << is_col_vector(x) << "\n\t x.size(): " << x.size() << "\n\t get_w().size(): " << get_w().size() << "\n\t this: " << this ); return dot(x,w); } decision_function<linear_kernel<matrix<double,0,1> > > get_decision_function ( ) const { // make sure requires clause is not broken DLIB_ASSERT(get_w().size() != 0, "\t decision_function rls::get_decision_function()" << "\n\t invalid arguments were given to this function" << "\n\t get_w().size(): " << get_w().size() << "\n\t this: " << this ); decision_function<linear_kernel<matrix<double,0,1> > > df; df.alpha.set_size(1); df.basis_vectors.set_size(1); df.b = 0; df.alpha = 1; df.basis_vectors(0) = w; return df; } friend inline void serialize(const rls& item, std::ostream& out) { int version = 2; serialize(version, out); serialize(item.w, out); serialize(item.R, out); serialize(item.C, out); serialize(item.forget_factor, out); serialize(item.cnt, out); serialize(item.apply_forget_factor_to_C, out); } friend inline void deserialize(rls& item, std::istream& in) { int version = 0; deserialize(version, in); if (!(1 <= version && version <= 2)) throw dlib::serialization_error("Unknown version number found while deserializing rls object."); if (version >= 1) { deserialize(item.w, in); deserialize(item.R, in); deserialize(item.C, in); deserialize(item.forget_factor, in); } item.cnt = 0; item.apply_forget_factor_to_C = false; if (version >= 2) { deserialize(item.cnt, in); deserialize(item.apply_forget_factor_to_C, in); } } private: void add_eye_to_inv( matrix<double>& m, double C ) /*! ensures - Let m == inv(M) - this function returns inv(M + C*identity_matrix<double>(m.nr())) !*/ { for (long r = 0; r < m.nr(); ++r) { m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r)); } } matrix<double,0,1> w; matrix<double> R; double C; double forget_factor; int cnt = 0; bool apply_forget_factor_to_C; // This object is here only to avoid reallocation during training. It don't // logically contribute to the state of this object. matrix<double,0,1> tmp; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_RLs_Hh_