001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet;
019
020import java.io.Serializable;
021import java.io.ObjectInputStream;
022import java.util.concurrent.atomic.AtomicReference;
023import org.apache.commons.math3.util.Precision;
024import org.apache.commons.math3.exception.DimensionMismatchException;
025
026
027/**
028 * Describes a neuron element of a neural network.
029 *
030 * This class aims to be thread-safe.
031 *
032 * @since 3.3
033 */
034public class Neuron implements Serializable {
035    /** Serializable. */
036    private static final long serialVersionUID = 20130207L;
037    /** Identifier. */
038    private final long identifier;
039    /** Length of the feature set. */
040    private final int size;
041    /** Neuron data. */
042    private final AtomicReference<double[]> features;
043
044    /**
045     * Creates a neuron.
046     * The size of the feature set is fixed to the length of the given
047     * argument.
048     * <br/>
049     * Constructor is package-private: Neurons must be
050     * {@link Network#createNeuron(double[]) created} by the network
051     * instance to which they will belong.
052     *
053     * @param identifier Identifier (assigned by the {@link Network}).
054     * @param features Initial values of the feature set.
055     */
056    Neuron(long identifier,
057           double[] features) {
058        this.identifier = identifier;
059        this.size = features.length;
060        this.features = new AtomicReference<double[]>(features.clone());
061    }
062
063    /**
064     * Gets the neuron's identifier.
065     *
066     * @return the identifier.
067     */
068    public long getIdentifier() {
069        return identifier;
070    }
071
072    /**
073     * Gets the length of the feature set.
074     *
075     * @return the number of features.
076     */
077    public int getSize() {
078        return size;
079    }
080
081    /**
082     * Gets the neuron's features.
083     *
084     * @return a copy of the neuron's features.
085     */
086    public double[] getFeatures() {
087        return features.get().clone();
088    }
089
090    /**
091     * Tries to atomically update the neuron's features.
092     * Update will be performed only if the expected values match the
093     * current values.<br/>
094     * In effect, when concurrent threads call this method, the state
095     * could be modified by one, so that it does not correspond to the
096     * the state assumed by another.
097     * Typically, a caller {@link #getFeatures() retrieves the current state},
098     * and uses it to compute the new state.
099     * During this computation, another thread might have done the same
100     * thing, and updated the state: If the current thread were to proceed
101     * with its own update, it would overwrite the new state (which might
102     * already have been used by yet other threads).
103     * To prevent this, the method does not perform the update when a
104     * concurrent modification has been detected, and returns {@code false}.
105     * When this happens, the caller should fetch the new current state,
106     * redo its computation, and call this method again.
107     *
108     * @param expect Current values of the features, as assumed by the caller.
109     * Update will never succeed if the contents of this array does not match
110     * the values returned by {@link #getFeatures()}.
111     * @param update Features's new values.
112     * @return {@code true} if the update was successful, {@code false}
113     * otherwise.
114     * @throws DimensionMismatchException if the length of {@code update} is
115     * not the same as specified in the {@link #Neuron(long,double[])
116     * constructor}.
117     */
118    public boolean compareAndSetFeatures(double[] expect,
119                                         double[] update) {
120        if (update.length != size) {
121            throw new DimensionMismatchException(update.length, size);
122        }
123
124        // Get the internal reference. Note that this must not be a copy;
125        // otherwise the "compareAndSet" below will always fail.
126        final double[] current = features.get();
127        if (!containSameValues(current, expect)) {
128            // Some other thread already modified the state.
129            return false;
130        }
131
132        if (features.compareAndSet(current, update.clone())) {
133            // The current thread could atomically update the state.
134            return true;
135        } else {
136            // Some other thread came first.
137            return false;
138        }
139    }
140
141    /**
142     * Checks whether the contents of both arrays is the same.
143     *
144     * @param current Current values.
145     * @param expect Expected values.
146     * @throws DimensionMismatchException if the length of {@code expected}
147     * is not the same as specified in the {@link #Neuron(long,double[])
148     * constructor}.
149     * @return {@code true} if the arrays contain the same values.
150     */
151    private boolean containSameValues(double[] current,
152                                      double[] expect) {
153        if (expect.length != size) {
154            throw new DimensionMismatchException(expect.length, size);
155        }
156
157        for (int i = 0; i < size; i++) {
158            if (!Precision.equals(current[i], expect[i])) {
159                return false;
160            }
161        }
162        return true;
163    }
164
165    /**
166     * Prevents proxy bypass.
167     *
168     * @param in Input stream.
169     */
170    private void readObject(ObjectInputStream in) {
171        throw new IllegalStateException();
172    }
173
174    /**
175     * Custom serialization.
176     *
177     * @return the proxy instance that will be actually serialized.
178     */
179    private Object writeReplace() {
180        return new SerializationProxy(identifier,
181                                      features.get());
182    }
183
184    /**
185     * Serialization.
186     */
187    private static class SerializationProxy implements Serializable {
188        /** Serializable. */
189        private static final long serialVersionUID = 20130207L;
190        /** Features. */
191        private final double[] features;
192        /** Identifier. */
193        private final long identifier;
194
195        /**
196         * @param identifier Identifier.
197         * @param features Features.
198         */
199        SerializationProxy(long identifier,
200                           double[] features) {
201            this.identifier = identifier;
202            this.features = features;
203        }
204
205        /**
206         * Custom serialization.
207         *
208         * @return the {@link Neuron} for which this instance is the proxy.
209         */
210        private Object readResolve() {
211            return new Neuron(identifier,
212                              features);
213        }
214    }
215}