001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.fitting; 018 019import java.util.ArrayList; 020import java.util.Collection; 021import java.util.Collections; 022import java.util.Comparator; 023import java.util.List; 024 025import org.apache.commons.math3.analysis.function.Gaussian; 026import org.apache.commons.math3.exception.NotStrictlyPositiveException; 027import org.apache.commons.math3.exception.NullArgumentException; 028import org.apache.commons.math3.exception.NumberIsTooSmallException; 029import org.apache.commons.math3.exception.OutOfRangeException; 030import org.apache.commons.math3.exception.ZeroException; 031import org.apache.commons.math3.exception.util.LocalizedFormats; 032import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder; 033import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem; 034import org.apache.commons.math3.linear.DiagonalMatrix; 035import org.apache.commons.math3.util.FastMath; 036 037/** 038 * Fits points to a {@link 039 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} 040 * function. 041 * <br/> 042 * The {@link #withStartPoint(double[]) initial guess values} must be passed 043 * in the following order: 044 * <ul> 045 * <li>Normalization</li> 046 * <li>Mean</li> 047 * <li>Sigma</li> 048 * </ul> 049 * The optimal values will be returned in the same order. 050 * 051 * <p> 052 * Usage example: 053 * <pre> 054 * WeightedObservedPoints obs = new WeightedObservedPoints(); 055 * obs.add(4.0254623, 531026.0); 056 * obs.add(4.03128248, 984167.0); 057 * obs.add(4.03839603, 1887233.0); 058 * obs.add(4.04421621, 2687152.0); 059 * obs.add(4.05132976, 3461228.0); 060 * obs.add(4.05326982, 3580526.0); 061 * obs.add(4.05779662, 3439750.0); 062 * obs.add(4.0636168, 2877648.0); 063 * obs.add(4.06943698, 2175960.0); 064 * obs.add(4.07525716, 1447024.0); 065 * obs.add(4.08237071, 717104.0); 066 * obs.add(4.08366408, 620014.0); 067 * double[] parameters = GaussianCurveFitter.create().fit(obs.toList()); 068 * </pre> 069 * 070 * @since 3.3 071 */ 072public class GaussianCurveFitter extends AbstractCurveFitter { 073 /** Parametric function to be fitted. */ 074 private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() { 075 @Override 076 public double value(double x, double ... p) { 077 double v = Double.POSITIVE_INFINITY; 078 try { 079 v = super.value(x, p); 080 } catch (NotStrictlyPositiveException e) { // NOPMD 081 // Do nothing. 082 } 083 return v; 084 } 085 086 @Override 087 public double[] gradient(double x, double ... p) { 088 double[] v = { Double.POSITIVE_INFINITY, 089 Double.POSITIVE_INFINITY, 090 Double.POSITIVE_INFINITY }; 091 try { 092 v = super.gradient(x, p); 093 } catch (NotStrictlyPositiveException e) { // NOPMD 094 // Do nothing. 095 } 096 return v; 097 } 098 }; 099 /** Initial guess. */ 100 private final double[] initialGuess; 101 /** Maximum number of iterations of the optimization algorithm. */ 102 private final int maxIter; 103 104 /** 105 * Contructor used by the factory methods. 106 * 107 * @param initialGuess Initial guess. If set to {@code null}, the initial guess 108 * will be estimated using the {@link ParameterGuesser}. 109 * @param maxIter Maximum number of iterations of the optimization algorithm. 110 */ 111 private GaussianCurveFitter(double[] initialGuess, 112 int maxIter) { 113 this.initialGuess = initialGuess; 114 this.maxIter = maxIter; 115 } 116 117 /** 118 * Creates a default curve fitter. 119 * The initial guess for the parameters will be {@link ParameterGuesser} 120 * computed automatically, and the maximum number of iterations of the 121 * optimization algorithm is set to {@link Integer#MAX_VALUE}. 122 * 123 * @return a curve fitter. 124 * 125 * @see #withStartPoint(double[]) 126 * @see #withMaxIterations(int) 127 */ 128 public static GaussianCurveFitter create() { 129 return new GaussianCurveFitter(null, Integer.MAX_VALUE); 130 } 131 132 /** 133 * Configure the start point (initial guess). 134 * @param newStart new start point (initial guess) 135 * @return a new instance. 136 */ 137 public GaussianCurveFitter withStartPoint(double[] newStart) { 138 return new GaussianCurveFitter(newStart.clone(), 139 maxIter); 140 } 141 142 /** 143 * Configure the maximum number of iterations. 144 * @param newMaxIter maximum number of iterations 145 * @return a new instance. 146 */ 147 public GaussianCurveFitter withMaxIterations(int newMaxIter) { 148 return new GaussianCurveFitter(initialGuess, 149 newMaxIter); 150 } 151 152 /** {@inheritDoc} */ 153 @Override 154 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { 155 156 // Prepare least-squares problem. 157 final int len = observations.size(); 158 final double[] target = new double[len]; 159 final double[] weights = new double[len]; 160 161 int i = 0; 162 for (WeightedObservedPoint obs : observations) { 163 target[i] = obs.getY(); 164 weights[i] = obs.getWeight(); 165 ++i; 166 } 167 168 final AbstractCurveFitter.TheoreticalValuesFunction model = 169 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations); 170 171 final double[] startPoint = initialGuess != null ? 172 initialGuess : 173 // Compute estimation. 174 new ParameterGuesser(observations).guess(); 175 176 // Return a new least squares problem set up to fit a Gaussian curve to the 177 // observed points. 178 return new LeastSquaresBuilder(). 179 maxEvaluations(Integer.MAX_VALUE). 180 maxIterations(maxIter). 181 start(startPoint). 182 target(target). 183 weight(new DiagonalMatrix(weights)). 184 model(model.getModelFunction(), model.getModelFunctionJacobian()). 185 build(); 186 187 } 188 189 /** 190 * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma} 191 * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric} 192 * based on the specified observed points. 193 */ 194 public static class ParameterGuesser { 195 /** Normalization factor. */ 196 private final double norm; 197 /** Mean. */ 198 private final double mean; 199 /** Standard deviation. */ 200 private final double sigma; 201 202 /** 203 * Constructs instance with the specified observed points. 204 * 205 * @param observations Observed points from which to guess the 206 * parameters of the Gaussian. 207 * @throws NullArgumentException if {@code observations} is 208 * {@code null}. 209 * @throws NumberIsTooSmallException if there are less than 3 210 * observations. 211 */ 212 public ParameterGuesser(Collection<WeightedObservedPoint> observations) { 213 if (observations == null) { 214 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); 215 } 216 if (observations.size() < 3) { 217 throw new NumberIsTooSmallException(observations.size(), 3, true); 218 } 219 220 final List<WeightedObservedPoint> sorted = sortObservations(observations); 221 final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0])); 222 223 norm = params[0]; 224 mean = params[1]; 225 sigma = params[2]; 226 } 227 228 /** 229 * Gets an estimation of the parameters. 230 * 231 * @return the guessed parameters, in the following order: 232 * <ul> 233 * <li>Normalization factor</li> 234 * <li>Mean</li> 235 * <li>Standard deviation</li> 236 * </ul> 237 */ 238 public double[] guess() { 239 return new double[] { norm, mean, sigma }; 240 } 241 242 /** 243 * Sort the observations. 244 * 245 * @param unsorted Input observations. 246 * @return the input observations, sorted. 247 */ 248 private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { 249 final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted); 250 251 final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() { 252 public int compare(WeightedObservedPoint p1, 253 WeightedObservedPoint p2) { 254 if (p1 == null && p2 == null) { 255 return 0; 256 } 257 if (p1 == null) { 258 return -1; 259 } 260 if (p2 == null) { 261 return 1; 262 } 263 if (p1.getX() < p2.getX()) { 264 return -1; 265 } 266 if (p1.getX() > p2.getX()) { 267 return 1; 268 } 269 if (p1.getY() < p2.getY()) { 270 return -1; 271 } 272 if (p1.getY() > p2.getY()) { 273 return 1; 274 } 275 if (p1.getWeight() < p2.getWeight()) { 276 return -1; 277 } 278 if (p1.getWeight() > p2.getWeight()) { 279 return 1; 280 } 281 return 0; 282 } 283 }; 284 285 Collections.sort(observations, cmp); 286 return observations; 287 } 288 289 /** 290 * Guesses the parameters based on the specified observed points. 291 * 292 * @param points Observed points, sorted. 293 * @return the guessed parameters (normalization factor, mean and 294 * sigma). 295 */ 296 private double[] basicGuess(WeightedObservedPoint[] points) { 297 final int maxYIdx = findMaxY(points); 298 final double n = points[maxYIdx].getY(); 299 final double m = points[maxYIdx].getX(); 300 301 double fwhmApprox; 302 try { 303 final double halfY = n + ((m - n) / 2); 304 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY); 305 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY); 306 fwhmApprox = fwhmX2 - fwhmX1; 307 } catch (OutOfRangeException e) { 308 // TODO: Exceptions should not be used for flow control. 309 fwhmApprox = points[points.length - 1].getX() - points[0].getX(); 310 } 311 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2))); 312 313 return new double[] { n, m, s }; 314 } 315 316 /** 317 * Finds index of point in specified points with the largest Y. 318 * 319 * @param points Points to search. 320 * @return the index in specified points array. 321 */ 322 private int findMaxY(WeightedObservedPoint[] points) { 323 int maxYIdx = 0; 324 for (int i = 1; i < points.length; i++) { 325 if (points[i].getY() > points[maxYIdx].getY()) { 326 maxYIdx = i; 327 } 328 } 329 return maxYIdx; 330 } 331 332 /** 333 * Interpolates using the specified points to determine X at the 334 * specified Y. 335 * 336 * @param points Points to use for interpolation. 337 * @param startIdx Index within points from which to start the search for 338 * interpolation bounds points. 339 * @param idxStep Index step for searching interpolation bounds points. 340 * @param y Y value for which X should be determined. 341 * @return the value of X for the specified Y. 342 * @throws ZeroException if {@code idxStep} is 0. 343 * @throws OutOfRangeException if specified {@code y} is not within the 344 * range of the specified {@code points}. 345 */ 346 private double interpolateXAtY(WeightedObservedPoint[] points, 347 int startIdx, 348 int idxStep, 349 double y) 350 throws OutOfRangeException { 351 if (idxStep == 0) { 352 throw new ZeroException(); 353 } 354 final WeightedObservedPoint[] twoPoints 355 = getInterpolationPointsForY(points, startIdx, idxStep, y); 356 final WeightedObservedPoint p1 = twoPoints[0]; 357 final WeightedObservedPoint p2 = twoPoints[1]; 358 if (p1.getY() == y) { 359 return p1.getX(); 360 } 361 if (p2.getY() == y) { 362 return p2.getX(); 363 } 364 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / 365 (p2.getY() - p1.getY())); 366 } 367 368 /** 369 * Gets the two bounding interpolation points from the specified points 370 * suitable for determining X at the specified Y. 371 * 372 * @param points Points to use for interpolation. 373 * @param startIdx Index within points from which to start search for 374 * interpolation bounds points. 375 * @param idxStep Index step for search for interpolation bounds points. 376 * @param y Y value for which X should be determined. 377 * @return the array containing two points suitable for determining X at 378 * the specified Y. 379 * @throws ZeroException if {@code idxStep} is 0. 380 * @throws OutOfRangeException if specified {@code y} is not within the 381 * range of the specified {@code points}. 382 */ 383 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, 384 int startIdx, 385 int idxStep, 386 double y) 387 throws OutOfRangeException { 388 if (idxStep == 0) { 389 throw new ZeroException(); 390 } 391 for (int i = startIdx; 392 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; 393 i += idxStep) { 394 final WeightedObservedPoint p1 = points[i]; 395 final WeightedObservedPoint p2 = points[i + idxStep]; 396 if (isBetween(y, p1.getY(), p2.getY())) { 397 if (idxStep < 0) { 398 return new WeightedObservedPoint[] { p2, p1 }; 399 } else { 400 return new WeightedObservedPoint[] { p1, p2 }; 401 } 402 } 403 } 404 405 // Boundaries are replaced by dummy values because the raised 406 // exception is caught and the message never displayed. 407 // TODO: Exceptions should not be used for flow control. 408 throw new OutOfRangeException(y, 409 Double.NEGATIVE_INFINITY, 410 Double.POSITIVE_INFINITY); 411 } 412 413 /** 414 * Determines whether a value is between two other values. 415 * 416 * @param value Value to test whether it is between {@code boundary1} 417 * and {@code boundary2}. 418 * @param boundary1 One end of the range. 419 * @param boundary2 Other end of the range. 420 * @return {@code true} if {@code value} is between {@code boundary1} and 421 * {@code boundary2} (inclusive), {@code false} otherwise. 422 */ 423 private boolean isBetween(double value, 424 double boundary1, 425 double boundary2) { 426 return (value >= boundary1 && value <= boundary2) || 427 (value >= boundary2 && value <= boundary1); 428 } 429 } 430}