1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math3.ml.neuralnet.twod;
19
20 import java.util.List;
21 import java.util.ArrayList;
22 import java.io.Serializable;
23 import java.io.ObjectInputStream;
24 import org.apache.commons.math3.ml.neuralnet.Neuron;
25 import org.apache.commons.math3.ml.neuralnet.Network;
26 import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
27 import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
28 import org.apache.commons.math3.exception.NumberIsTooSmallException;
29 import org.apache.commons.math3.exception.OutOfRangeException;
30 import org.apache.commons.math3.exception.MathInternalError;
31
32
33
34
35
36
37
38
39
40
41
42
43 public class NeuronSquareMesh2D implements Serializable {
44
45 private static final long serialVersionUID = 1L;
46
47 private final Network network;
48
49 private final int numberOfRows;
50
51 private final int numberOfColumns;
52
53 private final boolean wrapRows;
54
55 private final boolean wrapColumns;
56
57 private final SquareNeighbourhood neighbourhood;
58
59
60
61
62
63 private final long[][] identifiers;
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 NeuronSquareMesh2D(boolean wrapRowDim,
79 boolean wrapColDim,
80 SquareNeighbourhood neighbourhoodType,
81 double[][][] featuresList) {
82 numberOfRows = featuresList.length;
83 numberOfColumns = featuresList[0].length;
84
85 if (numberOfRows < 2) {
86 throw new NumberIsTooSmallException(numberOfRows, 2, true);
87 }
88 if (numberOfColumns < 2) {
89 throw new NumberIsTooSmallException(numberOfColumns, 2, true);
90 }
91
92 wrapRows = wrapRowDim;
93 wrapColumns = wrapColDim;
94 neighbourhood = neighbourhoodType;
95
96 final int fLen = featuresList[0][0].length;
97 network = new Network(0, fLen);
98 identifiers = new long[numberOfRows][numberOfColumns];
99
100
101 for (int i = 0; i < numberOfRows; i++) {
102 for (int j = 0; j < numberOfColumns; j++) {
103 identifiers[i][j] = network.createNeuron(featuresList[i][j]);
104 }
105 }
106
107
108 createLinks();
109 }
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136 public NeuronSquareMesh2D(int numRows,
137 boolean wrapRowDim,
138 int numCols,
139 boolean wrapColDim,
140 SquareNeighbourhood neighbourhoodType,
141 FeatureInitializer[] featureInit) {
142 if (numRows < 2) {
143 throw new NumberIsTooSmallException(numRows, 2, true);
144 }
145 if (numCols < 2) {
146 throw new NumberIsTooSmallException(numCols, 2, true);
147 }
148
149 numberOfRows = numRows;
150 wrapRows = wrapRowDim;
151 numberOfColumns = numCols;
152 wrapColumns = wrapColDim;
153 neighbourhood = neighbourhoodType;
154 identifiers = new long[numberOfRows][numberOfColumns];
155
156 final int fLen = featureInit.length;
157 network = new Network(0, fLen);
158
159
160 for (int i = 0; i < numRows; i++) {
161 for (int j = 0; j < numCols; j++) {
162 final double[] features = new double[fLen];
163 for (int fIndex = 0; fIndex < fLen; fIndex++) {
164 features[fIndex] = featureInit[fIndex].value();
165 }
166 identifiers[i][j] = network.createNeuron(features);
167 }
168 }
169
170
171 createLinks();
172 }
173
174
175
176
177
178
179
180
181
182
183 public Network getNetwork() {
184 return network;
185 }
186
187
188
189
190
191
192 public int getNumberOfRows() {
193 return numberOfRows;
194 }
195
196
197
198
199
200
201 public int getNumberOfColumns() {
202 return numberOfColumns;
203 }
204
205
206
207
208
209
210
211
212
213
214 public Neuron getNeuron(int i,
215 int j) {
216 if (i < 0 ||
217 i >= numberOfRows) {
218 throw new OutOfRangeException(i, 0, numberOfRows - 1);
219 }
220 if (j < 0 ||
221 j >= numberOfColumns) {
222 throw new OutOfRangeException(j, 0, numberOfColumns - 1);
223 }
224
225 return network.getNeuron(identifiers[i][j]);
226 }
227
228
229
230
231 private void createLinks() {
232
233 final List<Long> linkEnd = new ArrayList<Long>();
234 final int iLast = numberOfRows - 1;
235 final int jLast = numberOfColumns - 1;
236 for (int i = 0; i < numberOfRows; i++) {
237 for (int j = 0; j < numberOfColumns; j++) {
238 linkEnd.clear();
239
240 switch (neighbourhood) {
241
242 case MOORE:
243
244 if (i > 0) {
245 if (j > 0) {
246 linkEnd.add(identifiers[i - 1][j - 1]);
247 }
248 if (j < jLast) {
249 linkEnd.add(identifiers[i - 1][j + 1]);
250 }
251 }
252 if (i < iLast) {
253 if (j > 0) {
254 linkEnd.add(identifiers[i + 1][j - 1]);
255 }
256 if (j < jLast) {
257 linkEnd.add(identifiers[i + 1][j + 1]);
258 }
259 }
260 if (wrapRows) {
261 if (i == 0) {
262 if (j > 0) {
263 linkEnd.add(identifiers[iLast][j - 1]);
264 }
265 if (j < jLast) {
266 linkEnd.add(identifiers[iLast][j + 1]);
267 }
268 } else if (i == iLast) {
269 if (j > 0) {
270 linkEnd.add(identifiers[0][j - 1]);
271 }
272 if (j < jLast) {
273 linkEnd.add(identifiers[0][j + 1]);
274 }
275 }
276 }
277 if (wrapColumns) {
278 if (j == 0) {
279 if (i > 0) {
280 linkEnd.add(identifiers[i - 1][jLast]);
281 }
282 if (i < iLast) {
283 linkEnd.add(identifiers[i + 1][jLast]);
284 }
285 } else if (j == jLast) {
286 if (i > 0) {
287 linkEnd.add(identifiers[i - 1][0]);
288 }
289 if (i < iLast) {
290 linkEnd.add(identifiers[i + 1][0]);
291 }
292 }
293 }
294 if (wrapRows &&
295 wrapColumns) {
296 if (i == 0 &&
297 j == 0) {
298 linkEnd.add(identifiers[iLast][jLast]);
299 } else if (i == 0 &&
300 j == jLast) {
301 linkEnd.add(identifiers[iLast][0]);
302 } else if (i == iLast &&
303 j == 0) {
304 linkEnd.add(identifiers[0][jLast]);
305 } else if (i == iLast &&
306 j == jLast) {
307 linkEnd.add(identifiers[0][0]);
308 }
309 }
310
311
312
313
314
315
316 case VON_NEUMANN:
317
318 if (i > 0) {
319 linkEnd.add(identifiers[i - 1][j]);
320 }
321 if (i < iLast) {
322 linkEnd.add(identifiers[i + 1][j]);
323 }
324 if (wrapRows) {
325 if (i == 0) {
326 linkEnd.add(identifiers[iLast][j]);
327 } else if (i == iLast) {
328 linkEnd.add(identifiers[0][j]);
329 }
330 }
331
332
333 if (j > 0) {
334 linkEnd.add(identifiers[i][j - 1]);
335 }
336 if (j < jLast) {
337 linkEnd.add(identifiers[i][j + 1]);
338 }
339 if (wrapColumns) {
340 if (j == 0) {
341 linkEnd.add(identifiers[i][jLast]);
342 } else if (j == jLast) {
343 linkEnd.add(identifiers[i][0]);
344 }
345 }
346 break;
347
348 default:
349 throw new MathInternalError();
350 }
351
352 final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
353 for (long b : linkEnd) {
354 final Neuron bNeuron = network.getNeuron(b);
355
356
357 network.addLink(aNeuron, bNeuron);
358 }
359 }
360 }
361 }
362
363
364
365
366
367
368 private void readObject(ObjectInputStream in) {
369 throw new IllegalStateException();
370 }
371
372
373
374
375
376
377 private Object writeReplace() {
378 final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
379 for (int i = 0; i < numberOfRows; i++) {
380 for (int j = 0; j < numberOfColumns; j++) {
381 featuresList[i][j] = getNeuron(i, j).getFeatures();
382 }
383 }
384
385 return new SerializationProxy(wrapRows,
386 wrapColumns,
387 neighbourhood,
388 featuresList);
389 }
390
391
392
393
394 private static class SerializationProxy implements Serializable {
395
396 private static final long serialVersionUID = 20130226L;
397
398 private final boolean wrapRows;
399
400 private final boolean wrapColumns;
401
402 private final SquareNeighbourhood neighbourhood;
403
404 private final double[][][] featuresList;
405
406
407
408
409
410
411
412
413 SerializationProxy(boolean wrapRows,
414 boolean wrapColumns,
415 SquareNeighbourhood neighbourhood,
416 double[][][] featuresList) {
417 this.wrapRows = wrapRows;
418 this.wrapColumns = wrapColumns;
419 this.neighbourhood = neighbourhood;
420 this.featuresList = featuresList;
421 }
422
423
424
425
426
427
428 private Object readResolve() {
429 return new NeuronSquareMesh2D(wrapRows,
430 wrapColumns,
431 neighbourhood,
432 featuresList);
433 }
434 }
435 }