View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math3.distribution;
18  
19  import java.io.Serializable;
20  import java.lang.reflect.Array;
21  import java.util.ArrayList;
22  import java.util.Arrays;
23  import java.util.List;
24  
25  import org.apache.commons.math3.exception.MathArithmeticException;
26  import org.apache.commons.math3.exception.NotANumberException;
27  import org.apache.commons.math3.exception.NotFiniteNumberException;
28  import org.apache.commons.math3.exception.NotPositiveException;
29  import org.apache.commons.math3.exception.NotStrictlyPositiveException;
30  import org.apache.commons.math3.exception.NullArgumentException;
31  import org.apache.commons.math3.exception.util.LocalizedFormats;
32  import org.apache.commons.math3.random.RandomGenerator;
33  import org.apache.commons.math3.random.Well19937c;
34  import org.apache.commons.math3.util.MathArrays;
35  import org.apache.commons.math3.util.Pair;
36  
37  /**
38   * <p>A generic implementation of a
39   * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
40   * discrete probability distribution (Wikipedia)</a> over a finite sample space,
41   * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
42   * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
43   * probabilities to make them sum to one.</p>
44   *
45   * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can
46   * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
47   * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
48   * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
49   * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to null.</p>
50   *
51   * @param <T> type of the elements in the sample space.
52   * @since 3.2
53   */
54  public class EnumeratedDistribution<T> implements Serializable {
55  
56      /** Serializable UID. */
57      private static final long serialVersionUID = 20123308L;
58  
59      /**
60       * RNG instance used to generate samples from the distribution.
61       */
62      protected final RandomGenerator random;
63  
64      /**
65       * List of random variable values.
66       */
67      private final List<T> singletons;
68  
69      /**
70       * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
71       * probability[i] is the probability that a random variable following this distribution takes
72       * the value singletons[i].
73       */
74      private final double[] probabilities;
75  
76      /**
77       * Cumulative probabilities, cached to speed up sampling.
78       */
79      private final double[] cumulativeProbabilities;
80  
81      /**
82       * Create an enumerated distribution using the given probability mass function
83       * enumeration.
84       * <p>
85       * <b>Note:</b> this constructor will implicitly create an instance of
86       * {@link Well19937c} as random generator to be used for sampling only (see
87       * {@link #sample()} and {@link #sample(int)}). In case no sampling is
88       * needed for the created distribution, it is advised to pass {@code null}
89       * as random generator via the appropriate constructors to avoid the
90       * additional initialisation overhead.
91       *
92       * @param pmf probability mass function enumerated as a list of <T, probability>
93       * pairs.
94       * @throws NotPositiveException if any of the probabilities are negative.
95       * @throws NotFiniteNumberException if any of the probabilities are infinite.
96       * @throws NotANumberException if any of the probabilities are NaN.
97       * @throws MathArithmeticException all of the probabilities are 0.
98       */
99      public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
100         throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
101         this(new Well19937c(), pmf);
102     }
103 
104     /**
105      * Create an enumerated distribution using the given random number generator
106      * and probability mass function enumeration.
107      *
108      * @param rng random number generator.
109      * @param pmf probability mass function enumerated as a list of <T, probability>
110      * pairs.
111      * @throws NotPositiveException if any of the probabilities are negative.
112      * @throws NotFiniteNumberException if any of the probabilities are infinite.
113      * @throws NotANumberException if any of the probabilities are NaN.
114      * @throws MathArithmeticException all of the probabilities are 0.
115      */
116     public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf)
117         throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
118         random = rng;
119 
120         singletons = new ArrayList<T>(pmf.size());
121         final double[] probs = new double[pmf.size()];
122 
123         for (int i = 0; i < pmf.size(); i++) {
124             final Pair<T, Double> sample = pmf.get(i);
125             singletons.add(sample.getKey());
126             final double p = sample.getValue();
127             if (p < 0) {
128                 throw new NotPositiveException(sample.getValue());
129             }
130             if (Double.isInfinite(p)) {
131                 throw new NotFiniteNumberException(p);
132             }
133             if (Double.isNaN(p)) {
134                 throw new NotANumberException();
135             }
136             probs[i] = p;
137         }
138 
139         probabilities = MathArrays.normalizeArray(probs, 1.0);
140 
141         cumulativeProbabilities = new double[probabilities.length];
142         double sum = 0;
143         for (int i = 0; i < probabilities.length; i++) {
144             sum += probabilities[i];
145             cumulativeProbabilities[i] = sum;
146         }
147     }
148 
149     /**
150      * Reseed the random generator used to generate samples.
151      *
152      * @param seed the new seed
153      */
154     public void reseedRandomGenerator(long seed) {
155         random.setSeed(seed);
156     }
157 
158     /**
159      * <p>For a random variable {@code X} whose values are distributed according to
160      * this distribution, this method returns {@code P(X = x)}. In other words,
161      * this method represents the probability mass function (PMF) for the
162      * distribution.</p>
163      *
164      * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
165      * or both are null, then {@code probability(x1) = probability(x2)}.</p>
166      *
167      * @param x the point at which the PMF is evaluated
168      * @return the value of the probability mass function at {@code x}
169      */
170     double probability(final T x) {
171         double probability = 0;
172 
173         for (int i = 0; i < probabilities.length; i++) {
174             if ((x == null && singletons.get(i) == null) ||
175                 (x != null && x.equals(singletons.get(i)))) {
176                 probability += probabilities[i];
177             }
178         }
179 
180         return probability;
181     }
182 
183     /**
184      * <p>Return the probability mass function as a list of <value, probability> pairs.</p>
185      *
186      * <p>Note that if duplicate and / or null values were provided to the constructor
187      * when creating this EnumeratedDistribution, the returned list will contain these
188      * values.  If duplicates values exist, what is returned will not represent
189      * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
190      *
191      * @return the probability mass function.
192      */
193     public List<Pair<T, Double>> getPmf() {
194         final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length);
195 
196         for (int i = 0; i < probabilities.length; i++) {
197             samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i]));
198         }
199 
200         return samples;
201     }
202 
203     /**
204      * Generate a random value sampled from this distribution.
205      *
206      * @return a random value.
207      */
208     public T sample() {
209         final double randomValue = random.nextDouble();
210 
211         int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
212         if (index < 0) {
213             index = -index-1;
214         }
215 
216         if (index >= 0 && index < probabilities.length) {
217             if (randomValue < cumulativeProbabilities[index]) {
218                 return singletons.get(index);
219             }
220         }
221 
222         /* This should never happen, but it ensures we will return a correct
223          * object in case there is some floating point inequality problem
224          * wrt the cumulative probabilities. */
225         return singletons.get(singletons.size() - 1);
226     }
227 
228     /**
229      * Generate a random sample from the distribution.
230      *
231      * @param sampleSize the number of random values to generate.
232      * @return an array representing the random sample.
233      * @throws NotStrictlyPositiveException if {@code sampleSize} is not
234      * positive.
235      */
236     public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
237         if (sampleSize <= 0) {
238             throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
239                     sampleSize);
240         }
241 
242         final Object[] out = new Object[sampleSize];
243 
244         for (int i = 0; i < sampleSize; i++) {
245             out[i] = sample();
246         }
247 
248         return out;
249 
250     }
251 
252     /**
253      * Generate a random sample from the distribution.
254      * <p>
255      * If the requested samples fit in the specified array, it is returned
256      * therein. Otherwise, a new array is allocated with the runtime type of
257      * the specified array and the size of this collection.
258      *
259      * @param sampleSize the number of random values to generate.
260      * @param array the array to populate.
261      * @return an array representing the random sample.
262      * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
263      * @throws NullArgumentException if {@code array} is null
264      */
265     public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
266         if (sampleSize <= 0) {
267             throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
268         }
269 
270         if (array == null) {
271             throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
272         }
273 
274         T[] out;
275         if (array.length < sampleSize) {
276             @SuppressWarnings("unchecked") // safe as both are of type T
277             final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
278             out = unchecked;
279         } else {
280             out = array;
281         }
282 
283         for (int i = 0; i < sampleSize; i++) {
284             out[i] = sample();
285         }
286 
287         return out;
288 
289     }
290 
291 }