// Copyright (C) 2004 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_ #define DLIB_CONDITIONING_CLASS_KERNEl_4_ #include "conditioning_class_kernel_abstract.h" #include "../assert.h" #include "../algs.h" namespace dlib { template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > class conditioning_class_kernel_4 { /*! REQUIREMENTS ON pool_size pool_size > 0 this will be the number of nodes contained in our memory pool REQUIREMENTS ON mem_manager mem_manager is an implementation of memory_manager/memory_manager_kernel_abstract.h INITIAL VALUE total == 1 escapes == 1 next == 0 CONVENTION get_total() == total get_count(alphabet_size-1) == escapes if (next != 0) then next == pointer to the start of a linked list and the linked list is terminated by a node with a next pointer of 0. get_count(symbol) == node::count for the node where node::symbol==symbol or 0 if no such node currently exists. if (there is a node for the symbol) then LOW_COUNT(symbol) == the sum of all node's counts in the linked list up to but not including the node for the symbol. get_memory_usage() == global_state.memory_usage !*/ struct node { unsigned short symbol; unsigned short count; node* next; }; public: class global_state_type { public: global_state_type ( ) : memory_usage(pool_size*sizeof(node)+sizeof(global_state_type)) {} private: unsigned long memory_usage; typename mem_manager::template rebind<node>::other pool; friend class conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>; }; conditioning_class_kernel_4 ( global_state_type& global_state_ ); ~conditioning_class_kernel_4 ( ); void clear( ); bool increment_count ( unsigned long symbol, unsigned short amount = 1 ); unsigned long get_count ( unsigned long symbol ) const; inline unsigned long get_total ( ) const; unsigned long get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const; void get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const; unsigned long get_memory_usage ( ) const; global_state_type& get_global_state ( ); static unsigned long get_alphabet_size ( ); private: void half_counts ( ); /*! ensures - divides all counts by 2 but ensures that escapes is always at least 1 !*/ // restricted functions conditioning_class_kernel_4(conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>&); // copy constructor conditioning_class_kernel_4& operator=(conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>&); // assignment operator // data members unsigned short total; unsigned short escapes; node* next; global_state_type& global_state; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: conditioning_class_kernel_4 ( global_state_type& global_state_ ) : total(1), escapes(1), next(0), global_state(global_state_) { COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); // update memory usage global_state.memory_usage += sizeof(conditioning_class_kernel_4); } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: ~conditioning_class_kernel_4 ( ) { clear(); // update memory usage global_state.memory_usage -= sizeof(conditioning_class_kernel_4); } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: clear( ) { total = 1; escapes = 1; while (next) { node* temp = next; next = next->next; global_state.pool.deallocate(temp); } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_memory_usage( ) const { return global_state.memory_usage; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > typename conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::global_state_type& conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_global_state( ) { return global_state; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > bool conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: increment_count ( unsigned long symbol, unsigned short amount ) { if (symbol == alphabet_size-1) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); escapes += amount; total += amount; return true; } // find the symbol and increment it or add a new node to the list if (next) { node* temp = next; node* previous = 0; while (true) { if (temp->symbol == static_cast<unsigned short>(symbol)) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); // we have found the symbol total += amount; temp->count += amount; // if this node now has a count greater than its parent node if (previous && temp->count > previous->count) { // swap the nodes so that the nodes will be in semi-sorted order swap(temp->count,previous->count); swap(temp->symbol,previous->symbol); } return true; } else if (temp->next == 0) { // we did not find the symbol so try to add it to the list if (global_state.pool.get_number_of_allocations() < pool_size) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); node* t = global_state.pool.allocate(); t->next = 0; t->symbol = static_cast<unsigned short>(symbol); t->count = amount; temp->next = t; total += amount; return true; } else { // no memory left return false; } } else if (temp->count == 0) { // remove nodes that have a zero count if (previous) { previous->next = temp->next; node* t = temp; temp = temp->next; global_state.pool.deallocate(t); } else { next = temp->next; node* t = temp; temp = temp->next; global_state.pool.deallocate(t); } } else { previous = temp; temp = temp->next; } } // while (true) } // if there aren't any nodes in the list yet then do this instead else { if (global_state.pool.get_number_of_allocations() < pool_size) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); next = global_state.pool.allocate(); next->next = 0; next->symbol = static_cast<unsigned short>(symbol); next->count = amount; total += amount; return true; } else { // no memory left return false; } } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_count ( unsigned long symbol ) const { if (symbol == alphabet_size-1) { return escapes; } else { node* temp = next; while (temp) { if (temp->symbol == symbol) return temp->count; temp = temp->next; } return 0; } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_alphabet_size ( ) { return alphabet_size; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_total ( ) const { return total; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const { if (symbol != alphabet_size-1) { node* temp = next; unsigned long low = 0; while (temp) { if (temp->symbol == static_cast<unsigned short>(symbol)) { high_count = temp->count + low; low_count = low; total_count = total; return temp->count; } low += temp->count; temp = temp->next; } return 0; } else { total_count = total; high_count = total; low_count = total-escapes; return escapes; } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const { node* temp = next; unsigned long high = 0; while (true) { if (temp != 0) { high += temp->count; if (target < high) { symbol = temp->symbol; high_count = high; low_count = high - temp->count; return; } temp = temp->next; } else { // this must be the escape symbol symbol = alphabet_size-1; low_count = total-escapes; high_count = total; return; } } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>:: half_counts ( ) { total = 0; if (escapes > 1) escapes >>= 1; //divide all counts by 2 node* temp = next; while (temp) { temp->count >>= 1; total += temp->count; temp = temp->next; } total += escapes; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONDITIONING_CLASS_KERNEl_4_