1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math3.ml.neuralnet;
19
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalStateException;
35
36
37
38
39
40
41
42
43
44
45 public class Network
46 implements Iterable<Neuron>,
47 Serializable {
48
49 private static final long serialVersionUID = 20130207L;
50
51 private final ConcurrentHashMap<Long, Neuron> neuronMap
52 = new ConcurrentHashMap<Long, Neuron>();
53
54 private final AtomicLong nextId;
55
56 private final int featureSize;
57
58 private final ConcurrentHashMap<Long, Set<Long>> linkMap
59 = new ConcurrentHashMap<Long, Set<Long>>();
60
61
62
63
64
65 public static class NeuronIdentifierComparator
66 implements Comparator<Neuron>,
67 Serializable {
68
69 private static final long serialVersionUID = 20130207L;
70
71
72 public int compare(Neuron a,
73 Neuron b) {
74 final long aId = a.getIdentifier();
75 final long bId = b.getIdentifier();
76 return aId < bId ? -1 :
77 aId > bId ? 1 : 0;
78 }
79 }
80
81
82
83
84
85
86
87
88
89
90
91
92 Network(long nextId,
93 int featureSize,
94 Neuron[] neuronList,
95 long[][] neighbourIdList) {
96 final int numNeurons = neuronList.length;
97 if (numNeurons != neighbourIdList.length) {
98 throw new MathIllegalStateException();
99 }
100
101 for (int i = 0; i < numNeurons; i++) {
102 final Neuron n = neuronList[i];
103 final long id = n.getIdentifier();
104 if (id >= nextId) {
105 throw new MathIllegalStateException();
106 }
107 neuronMap.put(id, n);
108 linkMap.put(id, new HashSet<Long>());
109 }
110
111 for (int i = 0; i < numNeurons; i++) {
112 final long aId = neuronList[i].getIdentifier();
113 final Set<Long> aLinks = linkMap.get(aId);
114 for (Long bId : neighbourIdList[i]) {
115 if (neuronMap.get(bId) == null) {
116 throw new MathIllegalStateException();
117 }
118 addLinkToLinkSet(aLinks, bId);
119 }
120 }
121
122 this.nextId = new AtomicLong(nextId);
123 this.featureSize = featureSize;
124 }
125
126
127
128
129
130
131 public Network(long initialIdentifier,
132 int featureSize) {
133 nextId = new AtomicLong(initialIdentifier);
134 this.featureSize = featureSize;
135 }
136
137
138
139
140 public Iterator<Neuron> iterator() {
141 return neuronMap.values().iterator();
142 }
143
144
145
146
147
148
149
150
151
152 public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
153 final List<Neuron> neurons = new ArrayList<Neuron>();
154 neurons.addAll(neuronMap.values());
155
156 Collections.sort(neurons, comparator);
157
158 return neurons;
159 }
160
161
162
163
164
165
166
167
168
169
170 public long createNeuron(double[] features) {
171 if (features.length != featureSize) {
172 throw new DimensionMismatchException(features.length, featureSize);
173 }
174
175 final long id = createNextId();
176 neuronMap.put(id, new Neuron(id, features));
177 linkMap.put(id, new HashSet<Long>());
178 return id;
179 }
180
181
182
183
184
185
186
187
188
189
190 public void deleteNeuron(Neuron neuron) {
191 final Collection<Neuron> neighbours = getNeighbours(neuron);
192
193
194 for (Neuron n : neighbours) {
195 deleteLink(n, neuron);
196 }
197
198
199 neuronMap.remove(neuron.getIdentifier());
200 }
201
202
203
204
205
206
207 public int getFeaturesSize() {
208 return featureSize;
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222 public void addLink(Neuron a,
223 Neuron b) {
224 final long aId = a.getIdentifier();
225 final long bId = b.getIdentifier();
226
227
228 if (a != getNeuron(aId)) {
229 throw new NoSuchElementException(Long.toString(aId));
230 }
231 if (b != getNeuron(bId)) {
232 throw new NoSuchElementException(Long.toString(bId));
233 }
234
235
236 addLinkToLinkSet(linkMap.get(aId), bId);
237 }
238
239
240
241
242
243
244
245
246
247 private void addLinkToLinkSet(Set<Long> linkSet,
248 long id) {
249 linkSet.add(id);
250 }
251
252
253
254
255
256
257
258
259
260 public void deleteLink(Neuron a,
261 Neuron b) {
262 final long aId = a.getIdentifier();
263 final long bId = b.getIdentifier();
264
265
266 if (a != getNeuron(aId)) {
267 throw new NoSuchElementException(Long.toString(aId));
268 }
269 if (b != getNeuron(bId)) {
270 throw new NoSuchElementException(Long.toString(bId));
271 }
272
273
274 deleteLinkFromLinkSet(linkMap.get(aId), bId);
275 }
276
277
278
279
280
281
282
283
284
285 private void deleteLinkFromLinkSet(Set<Long> linkSet,
286 long id) {
287 linkSet.remove(id);
288 }
289
290
291
292
293
294
295
296
297
298 public Neuron getNeuron(long id) {
299 final Neuron n = neuronMap.get(id);
300 if (n == null) {
301 throw new NoSuchElementException(Long.toString(id));
302 }
303 return n;
304 }
305
306
307
308
309
310
311
312
313 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
314 return getNeighbours(neurons, null);
315 }
316
317
318
319
320
321
322
323
324
325
326
327
328
329 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
330 Iterable<Neuron> exclude) {
331 final Set<Long> idList = new HashSet<Long>();
332
333 for (Neuron n : neurons) {
334 idList.addAll(linkMap.get(n.getIdentifier()));
335 }
336 if (exclude != null) {
337 for (Neuron n : exclude) {
338 idList.remove(n.getIdentifier());
339 }
340 }
341
342 final List<Neuron> neuronList = new ArrayList<Neuron>();
343 for (Long id : idList) {
344 neuronList.add(getNeuron(id));
345 }
346
347 return neuronList;
348 }
349
350
351
352
353
354
355
356
357 public Collection<Neuron> getNeighbours(Neuron neuron) {
358 return getNeighbours(neuron, null);
359 }
360
361
362
363
364
365
366
367
368
369 public Collection<Neuron> getNeighbours(Neuron neuron,
370 Iterable<Neuron> exclude) {
371 final Set<Long> idList = linkMap.get(neuron.getIdentifier());
372 if (exclude != null) {
373 for (Neuron n : exclude) {
374 idList.remove(n.getIdentifier());
375 }
376 }
377
378 final List<Neuron> neuronList = new ArrayList<Neuron>();
379 for (Long id : idList) {
380 neuronList.add(getNeuron(id));
381 }
382
383 return neuronList;
384 }
385
386
387
388
389
390
391 private Long createNextId() {
392 return nextId.getAndIncrement();
393 }
394
395
396
397
398
399
400 private void readObject(ObjectInputStream in) {
401 throw new IllegalStateException();
402 }
403
404
405
406
407
408
409 private Object writeReplace() {
410 final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
411 final long[][] neighbourIdList = new long[neuronList.length][];
412
413 for (int i = 0; i < neuronList.length; i++) {
414 final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
415 final long[] neighboursId = new long[neighbours.size()];
416 int count = 0;
417 for (Neuron n : neighbours) {
418 neighboursId[count] = n.getIdentifier();
419 ++count;
420 }
421 neighbourIdList[i] = neighboursId;
422 }
423
424 return new SerializationProxy(nextId.get(),
425 featureSize,
426 neuronList,
427 neighbourIdList);
428 }
429
430
431
432
433 private static class SerializationProxy implements Serializable {
434
435 private static final long serialVersionUID = 20130207L;
436
437 private final long nextId;
438
439 private final int featureSize;
440
441 private final Neuron[] neuronList;
442
443 private final long[][] neighbourIdList;
444
445
446
447
448
449
450
451
452 SerializationProxy(long nextId,
453 int featureSize,
454 Neuron[] neuronList,
455 long[][] neighbourIdList) {
456 this.nextId = nextId;
457 this.featureSize = featureSize;
458 this.neuronList = neuronList;
459 this.neighbourIdList = neighbourIdList;
460 }
461
462
463
464
465
466
467 private Object readResolve() {
468 return new Network(nextId,
469 featureSize,
470 neuronList,
471 neighbourIdList);
472 }
473 }
474 }