/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math3.ml.neuralnet.twod;
19  
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
import org.apache.commons.math3.ml.neuralnet.Network;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
31  
32  /**
33   * Neural network with the topology of a two-dimensional surface.
34   * Each neuron defines one surface element.
35   * <br/>
36   * This network is primarily intended to represent a
37   * <a href="http://en.wikipedia.org/wiki/Kohonen">
38   *  Self Organizing Feature Map</a>.
39   *
40   * @see org.apache.commons.math3.ml.neuralnet.sofm
41   * @since 3.3
42   */
43  public class NeuronSquareMesh2D implements Serializable {
44      /** Serial version ID */
45      private static final long serialVersionUID = 1L;
46      /** Underlying network. */
47      private final Network network;
48      /** Number of rows. */
49      private final int numberOfRows;
50      /** Number of columns. */
51      private final int numberOfColumns;
52      /** Wrap. */
53      private final boolean wrapRows;
54      /** Wrap. */
55      private final boolean wrapColumns;
56      /** Neighbourhood type. */
57      private final SquareNeighbourhood neighbourhood;
58      /**
59       * Mapping of the 2D coordinates (in the rectangular mesh) to
60       * the neuron identifiers (attributed by the {@link #network}
61       * instance).
62       */
63      private final long[][] identifiers;
64  
65      /**
66       * Constructor with restricted access, solely used for deserialization.
67       *
68       * @param wrapRowDim Whether to wrap the first dimension (i.e the first
69       * and last neurons will be linked together).
70       * @param wrapColDim Whether to wrap the second dimension (i.e the first
71       * and last neurons will be linked together).
72       * @param neighbourhoodType Neighbourhood type.
73       * @param featuresList Arrays that will initialize the features sets of
74       * the network's neurons.
75       * @throws NumberIsTooSmallException if {@code numRows < 2} or
76       * {@code numCols < 2}.
77       */
78      NeuronSquareMesh2D(boolean wrapRowDim,
79                         boolean wrapColDim,
80                         SquareNeighbourhood neighbourhoodType,
81                         double[][][] featuresList) {
82          numberOfRows = featuresList.length;
83          numberOfColumns = featuresList[0].length;
84  
85          if (numberOfRows < 2) {
86              throw new NumberIsTooSmallException(numberOfRows, 2, true);
87          }
88          if (numberOfColumns < 2) {
89              throw new NumberIsTooSmallException(numberOfColumns, 2, true);
90          }
91  
92          wrapRows = wrapRowDim;
93          wrapColumns = wrapColDim;
94          neighbourhood = neighbourhoodType;
95  
96          final int fLen = featuresList[0][0].length;
97          network = new Network(0, fLen);
98          identifiers = new long[numberOfRows][numberOfColumns];
99  
100         // Add neurons.
101         for (int i = 0; i < numberOfRows; i++) {
102             for (int j = 0; j < numberOfColumns; j++) {
103                 identifiers[i][j] = network.createNeuron(featuresList[i][j]);
104             }
105         }
106 
107         // Add links.
108         createLinks();
109     }
110 
111     /**
112      * Creates a two-dimensional network composed of square cells:
113      * Each neuron not located on the border of the mesh has four
114      * neurons linked to it.
115      * <br/>
116      * The links are bi-directional.
117      * <br/>
118      * The topology of the network can also be a cylinder (if one
119      * of the dimensions is wrapped) or a torus (if both dimensions
120      * are wrapped).
121      *
122      * @param numRows Number of neurons in the first dimension.
123      * @param wrapRowDim Whether to wrap the first dimension (i.e the first
124      * and last neurons will be linked together).
125      * @param numCols Number of neurons in the second dimension.
126      * @param wrapColDim Whether to wrap the second dimension (i.e the first
127      * and last neurons will be linked together).
128      * @param neighbourhoodType Neighbourhood type.
129      * @param featureInit Array of functions that will initialize the
130      * corresponding element of the features set of each newly created
131      * neuron. In particular, the size of this array defines the size of
132      * feature set.
133      * @throws NumberIsTooSmallException if {@code numRows < 2} or
134      * {@code numCols < 2}.
135      */
136     public NeuronSquareMesh2D(int numRows,
137                               boolean wrapRowDim,
138                               int numCols,
139                               boolean wrapColDim,
140                               SquareNeighbourhood neighbourhoodType,
141                               FeatureInitializer[] featureInit) {
142         if (numRows < 2) {
143             throw new NumberIsTooSmallException(numRows, 2, true);
144         }
145         if (numCols < 2) {
146             throw new NumberIsTooSmallException(numCols, 2, true);
147         }
148 
149         numberOfRows = numRows;
150         wrapRows = wrapRowDim;
151         numberOfColumns = numCols;
152         wrapColumns = wrapColDim;
153         neighbourhood = neighbourhoodType;
154         identifiers = new long[numberOfRows][numberOfColumns];
155 
156         final int fLen = featureInit.length;
157         network = new Network(0, fLen);
158 
159         // Add neurons.
160         for (int i = 0; i < numRows; i++) {
161             for (int j = 0; j < numCols; j++) {
162                 final double[] features = new double[fLen];
163                 for (int fIndex = 0; fIndex < fLen; fIndex++) {
164                     features[fIndex] = featureInit[fIndex].value();
165                 }
166                 identifiers[i][j] = network.createNeuron(features);
167             }
168         }
169 
170         // Add links.
171         createLinks();
172     }
173 
174     /**
175      * Retrieves the underlying network.
176      * A reference is returned (enabling, for example, the network to be
177      * trained).
178      * This also implies that calling methods that modify the {@link Network}
179      * topology may cause this class to become inconsistent.
180      *
181      * @return the network.
182      */
183     public Network getNetwork() {
184         return network;
185     }
186 
187     /**
188      * Gets the number of neurons in each row of this map.
189      *
190      * @return the number of rows.
191      */
192     public int getNumberOfRows() {
193         return numberOfRows;
194     }
195 
196     /**
197      * Gets the number of neurons in each column of this map.
198      *
199      * @return the number of column.
200      */
201     public int getNumberOfColumns() {
202         return numberOfColumns;
203     }
204 
205     /**
206      * Retrieves the neuron at location {@code (i, j)} in the map.
207      *
208      * @param i Row index.
209      * @param j Column index.
210      * @return the neuron at {@code (i, j)}.
211      * @throws OutOfRangeException if {@code i} or {@code j} is
212      * out of range.
213      */
214     public Neuron getNeuron(int i,
215                             int j) {
216         if (i < 0 ||
217             i >= numberOfRows) {
218             throw new OutOfRangeException(i, 0, numberOfRows - 1);
219         }
220         if (j < 0 ||
221             j >= numberOfColumns) {
222             throw new OutOfRangeException(j, 0, numberOfColumns - 1);
223         }
224 
225         return network.getNeuron(identifiers[i][j]);
226     }
227 
228     /**
229      * Creates the neighbour relationships between neurons.
230      */
231     private void createLinks() {
232         // "linkEnd" will store the identifiers of the "neighbours".
233         final List<Long> linkEnd = new ArrayList<Long>();
234         final int iLast = numberOfRows - 1;
235         final int jLast = numberOfColumns - 1;
236         for (int i = 0; i < numberOfRows; i++) {
237             for (int j = 0; j < numberOfColumns; j++) {
238                 linkEnd.clear();
239 
240                 switch (neighbourhood) {
241 
242                 case MOORE:
243                     // Add links to "diagonal" neighbours.
244                     if (i > 0) {
245                         if (j > 0) {
246                             linkEnd.add(identifiers[i - 1][j - 1]);
247                         }
248                         if (j < jLast) {
249                             linkEnd.add(identifiers[i - 1][j + 1]);
250                         }
251                     }
252                     if (i < iLast) {
253                         if (j > 0) {
254                             linkEnd.add(identifiers[i + 1][j - 1]);
255                         }
256                         if (j < jLast) {
257                             linkEnd.add(identifiers[i + 1][j + 1]);
258                         }
259                     }
260                     if (wrapRows) {
261                         if (i == 0) {
262                             if (j > 0) {
263                                 linkEnd.add(identifiers[iLast][j - 1]);
264                             }
265                             if (j < jLast) {
266                                 linkEnd.add(identifiers[iLast][j + 1]);
267                             }
268                         } else if (i == iLast) {
269                             if (j > 0) {
270                                 linkEnd.add(identifiers[0][j - 1]);
271                             }
272                             if (j < jLast) {
273                                 linkEnd.add(identifiers[0][j + 1]);
274                             }
275                         }
276                     }
277                     if (wrapColumns) {
278                         if (j == 0) {
279                             if (i > 0) {
280                                 linkEnd.add(identifiers[i - 1][jLast]);
281                             }
282                             if (i < iLast) {
283                                 linkEnd.add(identifiers[i + 1][jLast]);
284                             }
285                         } else if (j == jLast) {
286                              if (i > 0) {
287                                  linkEnd.add(identifiers[i - 1][0]);
288                              }
289                              if (i < iLast) {
290                                  linkEnd.add(identifiers[i + 1][0]);
291                              }
292                         }
293                     }
294                     if (wrapRows &&
295                         wrapColumns) {
296                         if (i == 0 &&
297                             j == 0) {
298                             linkEnd.add(identifiers[iLast][jLast]);
299                         } else if (i == 0 &&
300                                    j == jLast) {
301                             linkEnd.add(identifiers[iLast][0]);
302                         } else if (i == iLast &&
303                                    j == 0) {
304                             linkEnd.add(identifiers[0][jLast]);
305                         } else if (i == iLast &&
306                                    j == jLast) {
307                             linkEnd.add(identifiers[0][0]);
308                         }
309                     }
310 
311                     // Case falls through since the "Moore" neighbourhood
312                     // also contains the neurons that belong to the "Von
313                     // Neumann" neighbourhood.
314 
315                     // fallthru (CheckStyle)
316                 case VON_NEUMANN:
317                     // Links to preceding and following "row".
318                     if (i > 0) {
319                         linkEnd.add(identifiers[i - 1][j]);
320                     }
321                     if (i < iLast) {
322                         linkEnd.add(identifiers[i + 1][j]);
323                     }
324                     if (wrapRows) {
325                         if (i == 0) {
326                             linkEnd.add(identifiers[iLast][j]);
327                         } else if (i == iLast) {
328                             linkEnd.add(identifiers[0][j]);
329                         }
330                     }
331 
332                     // Links to preceding and following "column".
333                     if (j > 0) {
334                         linkEnd.add(identifiers[i][j - 1]);
335                     }
336                     if (j < jLast) {
337                         linkEnd.add(identifiers[i][j + 1]);
338                     }
339                     if (wrapColumns) {
340                         if (j == 0) {
341                             linkEnd.add(identifiers[i][jLast]);
342                         } else if (j == jLast) {
343                             linkEnd.add(identifiers[i][0]);
344                         }
345                     }
346                     break;
347 
348                 default:
349                     throw new MathInternalError(); // Cannot happen.
350                 }
351 
352                 final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
353                 for (long b : linkEnd) {
354                     final Neuron bNeuron = network.getNeuron(b);
355                     // Link to all neighbours.
356                     // The reverse links will be added as the loop proceeds.
357                     network.addLink(aNeuron, bNeuron);
358                 }
359             }
360         }
361     }
362 
363     /**
364      * Prevents proxy bypass.
365      *
366      * @param in Input stream.
367      */
368     private void readObject(ObjectInputStream in) {
369         throw new IllegalStateException();
370     }
371 
372     /**
373      * Custom serialization.
374      *
375      * @return the proxy instance that will be actually serialized.
376      */
377     private Object writeReplace() {
378         final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
379         for (int i = 0; i < numberOfRows; i++) {
380             for (int j = 0; j < numberOfColumns; j++) {
381                 featuresList[i][j] = getNeuron(i, j).getFeatures();
382             }
383         }
384 
385         return new SerializationProxy(wrapRows,
386                                       wrapColumns,
387                                       neighbourhood,
388                                       featuresList);
389     }
390 
391     /**
392      * Serialization.
393      */
394     private static class SerializationProxy implements Serializable {
395         /** Serializable. */
396         private static final long serialVersionUID = 20130226L;
397         /** Wrap. */
398         private final boolean wrapRows;
399         /** Wrap. */
400         private final boolean wrapColumns;
401         /** Neighbourhood type. */
402         private final SquareNeighbourhood neighbourhood;
403         /** Neurons' features. */
404         private final double[][][] featuresList;
405 
406         /**
407          * @param wrapRows Whether the row dimension is wrapped.
408          * @param wrapColumns Whether the column dimension is wrapped.
409          * @param neighbourhood Neighbourhood type.
410          * @param featuresList List of neurons features.
411          * {@code neuronList}.
412          */
413         SerializationProxy(boolean wrapRows,
414                            boolean wrapColumns,
415                            SquareNeighbourhood neighbourhood,
416                            double[][][] featuresList) {
417             this.wrapRows = wrapRows;
418             this.wrapColumns = wrapColumns;
419             this.neighbourhood = neighbourhood;
420             this.featuresList = featuresList;
421         }
422 
423         /**
424          * Custom serialization.
425          *
426          * @return the {@link Neuron} for which this instance is the proxy.
427          */
428         private Object readResolve() {
429             return new NeuronSquareMesh2D(wrapRows,
430                                           wrapColumns,
431                                           neighbourhood,
432                                           featuresList);
433         }
434     }
435 }