View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.ml.neuralnet;
19  
20  import java.io.Serializable;
21  import java.io.ObjectInputStream;
22  import java.util.NoSuchElementException;
23  import java.util.List;
24  import java.util.ArrayList;
25  import java.util.Set;
26  import java.util.HashSet;
27  import java.util.Collection;
28  import java.util.Iterator;
29  import java.util.Comparator;
30  import java.util.Collections;
31  import java.util.concurrent.ConcurrentHashMap;
32  import java.util.concurrent.atomic.AtomicLong;
33  import org.apache.commons.math3.exception.DimensionMismatchException;
34  import org.apache.commons.math3.exception.MathIllegalStateException;
35  
36  /**
37   * Neural network, composed of {@link Neuron} instances and the links
38   * between them.
39   *
40   * Although updating a neuron's state is thread-safe, modifying the
41   * network's topology (adding or removing links) is not.
42   *
43   * @since 3.3
44   */
45  public class Network
46      implements Iterable<Neuron>,
47                 Serializable {
48      /** Serializable. */
49      private static final long serialVersionUID = 20130207L;
50      /** Neurons. */
51      private final ConcurrentHashMap<Long, Neuron> neuronMap
52          = new ConcurrentHashMap<Long, Neuron>();
53      /** Next available neuron identifier. */
54      private final AtomicLong nextId;
55      /** Neuron's features set size. */
56      private final int featureSize;
57      /** Links. */
58      private final ConcurrentHashMap<Long, Set<Long>> linkMap
59          = new ConcurrentHashMap<Long, Set<Long>>();
60  
61      /**
62       * Comparator that prescribes an order of the neurons according
63       * to the increasing order of their identifier.
64       */
65      public static class NeuronIdentifierComparator
66          implements Comparator<Neuron>,
67                     Serializable {
68          /** Version identifier. */
69          private static final long serialVersionUID = 20130207L;
70  
71          /** {@inheritDoc} */
72          public int compare(Neuron a,
73                             Neuron b) {
74              final long aId = a.getIdentifier();
75              final long bId = b.getIdentifier();
76              return aId < bId ? -1 :
77                  aId > bId ? 1 : 0;
78          }
79      }
80  
81      /**
82       * Constructor with restricted access, solely used for deserialization.
83       *
84       * @param nextId Next available identifier.
85       * @param featureSize Number of features.
86       * @param neuronList Neurons.
87       * @param neighbourIdList Links associated to each of the neurons in
88       * {@code neuronList}.
89       * @throws MathIllegalStateException if an inconsistency is detected
90       * (which probably means that the serialized form has been corrupted).
91       */
92      Network(long nextId,
93              int featureSize,
94              Neuron[] neuronList,
95              long[][] neighbourIdList) {
96          final int numNeurons = neuronList.length;
97          if (numNeurons != neighbourIdList.length) {
98              throw new MathIllegalStateException();
99          }
100 
101         for (int i = 0; i < numNeurons; i++) {
102             final Neuron n = neuronList[i];
103             final long id = n.getIdentifier();
104             if (id >= nextId) {
105                 throw new MathIllegalStateException();
106             }
107             neuronMap.put(id, n);
108             linkMap.put(id, new HashSet<Long>());
109         }
110 
111         for (int i = 0; i < numNeurons; i++) {
112             final long aId = neuronList[i].getIdentifier();
113             final Set<Long> aLinks = linkMap.get(aId);
114             for (Long bId : neighbourIdList[i]) {
115                 if (neuronMap.get(bId) == null) {
116                     throw new MathIllegalStateException();
117                 }
118                 addLinkToLinkSet(aLinks, bId);
119             }
120         }
121 
122         this.nextId = new AtomicLong(nextId);
123         this.featureSize = featureSize;
124     }
125 
126     /**
127      * @param initialIdentifier Identifier for the first neuron that
128      * will be added to this network.
129      * @param featureSize Size of the neuron's features.
130      */
131     public Network(long initialIdentifier,
132                    int featureSize) {
133         nextId = new AtomicLong(initialIdentifier);
134         this.featureSize = featureSize;
135     }
136 
137     /**
138      * {@inheritDoc}
139      */
140     public Iterator<Neuron> iterator() {
141         return neuronMap.values().iterator();
142     }
143 
144     /**
145      * Creates a list of the neurons, sorted in a custom order.
146      *
147      * @param comparator {@link Comparator} used for sorting the neurons.
148      * @return a list of neurons, sorted in the order prescribed by the
149      * given {@code comparator}.
150      * @see NeuronIdentifierComparator
151      */
152     public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
153         final List<Neuron> neurons = new ArrayList<Neuron>();
154         neurons.addAll(neuronMap.values());
155 
156         Collections.sort(neurons, comparator);
157 
158         return neurons;
159     }
160 
161     /**
162      * Creates a neuron and assigns it a unique identifier.
163      *
164      * @param features Initial values for the neuron's features.
165      * @return the neuron's identifier.
166      * @throws DimensionMismatchException if the length of {@code features}
167      * is different from the expected size (as set by the
168      * {@link #Network(long,int) constructor}).
169      */
170     public long createNeuron(double[] features) {
171         if (features.length != featureSize) {
172             throw new DimensionMismatchException(features.length, featureSize);
173         }
174 
175         final long id = createNextId();
176         neuronMap.put(id, new Neuron(id, features));
177         linkMap.put(id, new HashSet<Long>());
178         return id;
179     }
180 
181     /**
182      * Deletes a neuron.
183      * Links from all neighbours to the removed neuron will also be
184      * {@link #deleteLink(Neuron,Neuron) deleted}.
185      *
186      * @param neuron Neuron to be removed from this network.
187      * @throws NoSuchElementException if {@code n} does not belong to
188      * this network.
189      */
190     public void deleteNeuron(Neuron neuron) {
191         final Collection<Neuron> neighbours = getNeighbours(neuron);
192 
193         // Delete links to from neighbours.
194         for (Neuron n : neighbours) {
195             deleteLink(n, neuron);
196         }
197 
198         // Remove neuron.
199         neuronMap.remove(neuron.getIdentifier());
200     }
201 
202     /**
203      * Gets the size of the neurons' features set.
204      *
205      * @return the size of the features set.
206      */
207     public int getFeaturesSize() {
208         return featureSize;
209     }
210 
211     /**
212      * Adds a link from neuron {@code a} to neuron {@code b}.
213      * Note: the link is not bi-directional; if a bi-directional link is
214      * required, an additional call must be made with {@code a} and
215      * {@code b} exchanged in the argument list.
216      *
217      * @param a Neuron.
218      * @param b Neuron.
219      * @throws NoSuchElementException if the neurons do not exist in the
220      * network.
221      */
222     public void addLink(Neuron a,
223                         Neuron b) {
224         final long aId = a.getIdentifier();
225         final long bId = b.getIdentifier();
226 
227         // Check that the neurons belong to this network.
228         if (a != getNeuron(aId)) {
229             throw new NoSuchElementException(Long.toString(aId));
230         }
231         if (b != getNeuron(bId)) {
232             throw new NoSuchElementException(Long.toString(bId));
233         }
234 
235         // Add link from "a" to "b".
236         addLinkToLinkSet(linkMap.get(aId), bId);
237     }
238 
239     /**
240      * Adds a link to neuron {@code id} in given {@code linkSet}.
241      * Note: no check verifies that the identifier indeed belongs
242      * to this network.
243      *
244      * @param linkSet Neuron identifier.
245      * @param id Neuron identifier.
246      */
247     private void addLinkToLinkSet(Set<Long> linkSet,
248                                   long id) {
249         linkSet.add(id);
250     }
251 
252     /**
253      * Deletes the link between neurons {@code a} and {@code b}.
254      *
255      * @param a Neuron.
256      * @param b Neuron.
257      * @throws NoSuchElementException if the neurons do not exist in the
258      * network.
259      */
260     public void deleteLink(Neuron a,
261                            Neuron b) {
262         final long aId = a.getIdentifier();
263         final long bId = b.getIdentifier();
264 
265         // Check that the neurons belong to this network.
266         if (a != getNeuron(aId)) {
267             throw new NoSuchElementException(Long.toString(aId));
268         }
269         if (b != getNeuron(bId)) {
270             throw new NoSuchElementException(Long.toString(bId));
271         }
272 
273         // Delete link from "a" to "b".
274         deleteLinkFromLinkSet(linkMap.get(aId), bId);
275     }
276 
277     /**
278      * Deletes a link to neuron {@code id} in given {@code linkSet}.
279      * Note: no check verifies that the identifier indeed belongs
280      * to this network.
281      *
282      * @param linkSet Neuron identifier.
283      * @param id Neuron identifier.
284      */
285     private void deleteLinkFromLinkSet(Set<Long> linkSet,
286                                        long id) {
287         linkSet.remove(id);
288     }
289 
290     /**
291      * Retrieves the neuron with the given (unique) {@code id}.
292      *
293      * @param id Identifier.
294      * @return the neuron associated with the given {@code id}.
295      * @throws NoSuchElementException if the neuron does not exist in the
296      * network.
297      */
298     public Neuron getNeuron(long id) {
299         final Neuron n = neuronMap.get(id);
300         if (n == null) {
301             throw new NoSuchElementException(Long.toString(id));
302         }
303         return n;
304     }
305 
306     /**
307      * Retrieves the neurons in the neighbourhood of any neuron in the
308      * {@code neurons} list.
309      * @param neurons Neurons for which to retrieve the neighbours.
310      * @return the list of neighbours.
311      * @see #getNeighbours(Iterable,Iterable)
312      */
313     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
314         return getNeighbours(neurons, null);
315     }
316 
317     /**
318      * Retrieves the neurons in the neighbourhood of any neuron in the
319      * {@code neurons} list.
320      * The {@code exclude} list allows to retrieve the "concentric"
321      * neighbourhoods by removing the neurons that belong to the inner
322      * "circles".
323      *
324      * @param neurons Neurons for which to retrieve the neighbours.
325      * @param exclude Neurons to exclude from the returned list.
326      * Can be {@code null}.
327      * @return the list of neighbours.
328      */
329     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
330                                             Iterable<Neuron> exclude) {
331         final Set<Long> idList = new HashSet<Long>();
332 
333         for (Neuron n : neurons) {
334             idList.addAll(linkMap.get(n.getIdentifier()));
335         }
336         if (exclude != null) {
337             for (Neuron n : exclude) {
338                 idList.remove(n.getIdentifier());
339             }
340         }
341 
342         final List<Neuron> neuronList = new ArrayList<Neuron>();
343         for (Long id : idList) {
344             neuronList.add(getNeuron(id));
345         }
346 
347         return neuronList;
348     }
349 
350     /**
351      * Retrieves the neighbours of the given neuron.
352      *
353      * @param neuron Neuron for which to retrieve the neighbours.
354      * @return the list of neighbours.
355      * @see #getNeighbours(Neuron,Iterable)
356      */
357     public Collection<Neuron> getNeighbours(Neuron neuron) {
358         return getNeighbours(neuron, null);
359     }
360 
361     /**
362      * Retrieves the neighbours of the given neuron.
363      *
364      * @param neuron Neuron for which to retrieve the neighbours.
365      * @param exclude Neurons to exclude from the returned list.
366      * Can be {@code null}.
367      * @return the list of neighbours.
368      */
369     public Collection<Neuron> getNeighbours(Neuron neuron,
370                                             Iterable<Neuron> exclude) {
371         final Set<Long> idList = linkMap.get(neuron.getIdentifier());
372         if (exclude != null) {
373             for (Neuron n : exclude) {
374                 idList.remove(n.getIdentifier());
375             }
376         }
377 
378         final List<Neuron> neuronList = new ArrayList<Neuron>();
379         for (Long id : idList) {
380             neuronList.add(getNeuron(id));
381         }
382 
383         return neuronList;
384     }
385 
386     /**
387      * Creates a neuron identifier.
388      *
389      * @return a value that will serve as a unique identifier.
390      */
391     private Long createNextId() {
392         return nextId.getAndIncrement();
393     }
394 
395     /**
396      * Prevents proxy bypass.
397      *
398      * @param in Input stream.
399      */
400     private void readObject(ObjectInputStream in) {
401         throw new IllegalStateException();
402     }
403 
404     /**
405      * Custom serialization.
406      *
407      * @return the proxy instance that will be actually serialized.
408      */
409     private Object writeReplace() {
410         final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
411         final long[][] neighbourIdList = new long[neuronList.length][];
412 
413         for (int i = 0; i < neuronList.length; i++) {
414             final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
415             final long[] neighboursId = new long[neighbours.size()];
416             int count = 0;
417             for (Neuron n : neighbours) {
418                 neighboursId[count] = n.getIdentifier();
419                 ++count;
420             }
421             neighbourIdList[i] = neighboursId;
422         }
423 
424         return new SerializationProxy(nextId.get(),
425                                       featureSize,
426                                       neuronList,
427                                       neighbourIdList);
428     }
429 
430     /**
431      * Serialization.
432      */
433     private static class SerializationProxy implements Serializable {
434         /** Serializable. */
435         private static final long serialVersionUID = 20130207L;
436         /** Next identifier. */
437         private final long nextId;
438         /** Number of features. */
439         private final int featureSize;
440         /** Neurons. */
441         private final Neuron[] neuronList;
442         /** Links. */
443         private final long[][] neighbourIdList;
444 
445         /**
446          * @param nextId Next available identifier.
447          * @param featureSize Number of features.
448          * @param neuronList Neurons.
449          * @param neighbourIdList Links associated to each of the neurons in
450          * {@code neuronList}.
451          */
452         SerializationProxy(long nextId,
453                            int featureSize,
454                            Neuron[] neuronList,
455                            long[][] neighbourIdList) {
456             this.nextId = nextId;
457             this.featureSize = featureSize;
458             this.neuronList = neuronList;
459             this.neighbourIdList = neighbourIdList;
460         }
461 
462         /**
463          * Custom serialization.
464          *
465          * @return the {@link Network} for which this instance is the proxy.
466          */
467         private Object readResolve() {
468             return new Network(nextId,
469                                featureSize,
470                                neuronList,
471                                neighbourIdList);
472         }
473     }
474 }