Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/NelderMeadSimplex.cs @ 17985

Last change to this file as of revision 17985: revision 17984, checked in by gkronber 3 years ago.

#3106: updated the implementation based on the reply by Moscato.

File size: 10.7 KB
Line 
1using System;
2
namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
  /// <summary>
  /// Reason why <see cref="NelderMeadSimplex.Regress"/> stopped iterating.
  /// </summary>
  public enum TerminationReason {
    /// <summary>The permitted number of objective-function evaluations was used up.</summary>
    MaxFunctionEvaluations,
    /// <summary>The relative spread of the error values across the simplex fell below the tolerance.</summary>
    Converged,
    /// <summary>No termination reason has been determined.</summary>
    Unspecified
  }

  /// <summary>
  /// Objective function minimized by <see cref="NelderMeadSimplex"/>: maps a vector of
  /// constants to the error value that is to be minimized.
  /// </summary>
  /// <param name="constants">The current values of the constants being optimized.</param>
  /// <returns>The error value at <paramref name="constants"/>; smaller is better.</returns>
  public delegate double ObjectiveFunctionDelegate(double[] constants);


  /// <summary>
  /// Derivative-free minimization of an <see cref="ObjectiveFunctionDelegate"/> with the
  /// Nelder-Mead downhill simplex method (reflection, expansion, contraction and shrink steps).
  /// </summary>
  public sealed class NelderMeadSimplex {
    // Small value added to the denominator of the relative convergence test so that the
    // test stays finite when both the best and the worst error value are zero.
    private const double Jitter = 1e-10d;

    /// <summary>
    /// Minimizes <paramref name="objectiveFunction"/> starting from the given constants.
    /// </summary>
    /// <param name="simplexConstants">Initial value and initial perturbation (step size) for each constant;
    /// every constant is one dimension of the simplex.</param>
    /// <param name="convergenceTolerance">Iteration stops when the relative difference between the worst and
    /// the best error value of the simplex drops below this value.</param>
    /// <param name="maxEvaluations">Budget of objective-function evaluations, including the evaluations of the
    /// initial simplex. The budget is checked before each iteration, so the last iteration may overshoot it.</param>
    /// <param name="objectiveFunction">The function to minimize.</param>
    /// <returns>The termination reason, the best constants found, their error value and the number of
    /// objective-function evaluations performed.</returns>
    /// <exception cref="InvalidOperationException">If <paramref name="objectiveFunction"/> or
    /// <paramref name="simplexConstants"/> is null.</exception>
    public static RegressionResult Regress(SimplexConstant[] simplexConstants, double convergenceTolerance, int maxEvaluations,
                                           ObjectiveFunctionDelegate objectiveFunction) {
      // confirm that we are in a position to commence
      if (objectiveFunction == null)
        throw new InvalidOperationException("ObjectiveFunction must be set to a valid ObjectiveFunctionDelegate");

      if (simplexConstants == null)
        throw new InvalidOperationException("SimplexConstants must be initialized");

      int numDimensions = simplexConstants.Length;
      int numVertices = numDimensions + 1;

      if (numDimensions == 0) {
        // Nothing to optimize: the "simplex" is a single point. Ranking it below would need at
        // least two vertices, so evaluate it once and report it directly.
        Vector origin = new Vector(0);
        return new RegressionResult(TerminationReason.Converged, origin.Components,
                                    objectiveFunction(origin.Components), 1);
      }

      // create the initial simplex and evaluate the objective at every vertex
      Vector[] vertices = _initializeVertices(simplexConstants);
      double[] errorValues = _initializeErrorValues(vertices, objectiveFunction);
      int evaluationCount = numVertices; // the initial simplex costs one evaluation per vertex

      TerminationReason terminationReason = TerminationReason.Unspecified;
      ErrorProfile errorProfile;

      // iterate until we converge or exhaust the permitted number of evaluations
      while (true) {
        // Re-rank the vertices. The previous step may have replaced the worst vertex with a new
        // best one or re-evaluated all vertices, so the profile must be fresh before it is used
        // either for the next step or for the final result.
        errorProfile = _evaluateSimplex(errorValues);

        // see if the range in point heights is small enough to exit
        if (_hasConverged(convergenceTolerance, errorProfile, errorValues)) {
          terminationReason = TerminationReason.Converged;
          break;
        }

        // check whether we have used up our allotted number of evaluations
        if (evaluationCount >= maxEvaluations) {
          terminationReason = TerminationReason.MaxFunctionEvaluations;
          break;
        }

        // attempt a reflection of the worst vertex through the centroid of the others
        double reflectionPointValue = _tryToScaleSimplex(-1.0, errorProfile, vertices, errorValues, objectiveFunction);
        evaluationCount++;
        if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) {
          // it's at least as good as the best point, so also try expanding further in that
          // direction; the expansion point is only kept if it beats the reflected point
          _tryToScaleSimplex(2.0, errorProfile, vertices, errorValues, objectiveFunction);
          evaluationCount++;
        } else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) {
          // the reflected point is no better than the second-worst vertex, so contract
          // towards the centroid to look for an intermediate point
          double currentWorst = errorValues[errorProfile.HighestIndex]; // min(old worst, reflected value)
          double contractionPointValue = _tryToScaleSimplex(0.5, errorProfile, vertices, errorValues, objectiveFunction);
          evaluationCount++;
          if (contractionPointValue >= currentWorst) {
            // the contraction did not help either, so shrink all vertices towards the best one;
            // the error profile is recomputed at the start of the next iteration
            _shrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
            evaluationCount += numVertices - 1; // every vertex except the best one was re-evaluated
          }
        }
      }

      return new RegressionResult(terminationReason, vertices[errorProfile.LowestIndex].Components,
                                  errorValues[errorProfile.LowestIndex], evaluationCount);
    }

    /// <summary>
    /// Evaluate the objective function at each vertex to create a corresponding
    /// array of error values, one per vertex.
    /// </summary>
    /// <param name="vertices">The vertices of the simplex.</param>
    /// <param name="objectiveFunction">The function to evaluate.</param>
    /// <returns>errorValues[i] is the objective value at vertices[i].</returns>
    private static double[] _initializeErrorValues(Vector[] vertices, ObjectiveFunctionDelegate objectiveFunction) {
      double[] errorValues = new double[vertices.Length];
      for (int i = 0; i < vertices.Length; i++) {
        errorValues[i] = objectiveFunction(vertices[i].Components);
      }
      return errorValues;
    }

    /// <summary>
    /// Check whether the error values of the worst and the best vertex are so close (relative to
    /// their magnitude) that we consider ourselves to have converged.
    /// </summary>
    /// <param name="convergenceTolerance">Threshold for the relative range.</param>
    /// <param name="errorProfile">Indices of the best and the worst vertex.</param>
    /// <param name="errorValues">Error value of every vertex.</param>
    /// <returns>True if 2*|high - low| / (|high| + |low| + Jitter) is below the tolerance.</returns>
    private static bool _hasConverged(double convergenceTolerance, ErrorProfile errorProfile, double[] errorValues) {
      double highest = errorValues[errorProfile.HighestIndex];
      double lowest = errorValues[errorProfile.LowestIndex];
      double range = 2 * Math.Abs(highest - lowest) / (Math.Abs(highest) + Math.Abs(lowest) + Jitter);
      return range < convergenceTolerance;
    }

    /// <summary>
    /// Examine all error values to find the best (lowest), worst (highest) and
    /// second-worst (next highest) vertex.
    /// </summary>
    /// <param name="errorValues">Error value of every vertex; must contain at least two entries.</param>
    /// <returns>The indices of the lowest, highest and next-highest error value.</returns>
    private static ErrorProfile _evaluateSimplex(double[] errorValues) {
      ErrorProfile errorProfile = new ErrorProfile();
      // seed highest / next-highest with the first two vertices so both indices are always distinct
      if (errorValues[0] > errorValues[1]) {
        errorProfile.HighestIndex = 0;
        errorProfile.NextHighestIndex = 1;
      } else {
        errorProfile.HighestIndex = 1;
        errorProfile.NextHighestIndex = 0;
      }
      errorProfile.LowestIndex = 0;

      for (int index = 0; index < errorValues.Length; index++) {
        double errorValue = errorValues[index];
        if (errorValue <= errorValues[errorProfile.LowestIndex]) {
          errorProfile.LowestIndex = index;
        }
        if (errorValue > errorValues[errorProfile.HighestIndex]) {
          errorProfile.NextHighestIndex = errorProfile.HighestIndex; // downgrade the current highest to next highest
          errorProfile.HighestIndex = index;
        } else if (errorValue > errorValues[errorProfile.NextHighestIndex] && index != errorProfile.HighestIndex) {
          errorProfile.NextHighestIndex = index;
        }
      }

      return errorProfile;
    }

    /// <summary>
    /// Construct an initial simplex, given starting guesses for the constants and
    /// initial step sizes for each dimension: P(0) is the starting guess and
    /// P(i+1) = P(0) + InitialPerturbation(i) * UnitVector(i).
    /// </summary>
    /// <param name="simplexConstants">Starting value and perturbation of every constant.</param>
    /// <returns>The simplexConstants.Length + 1 vertices of the initial simplex.</returns>
    private static Vector[] _initializeVertices(SimplexConstant[] simplexConstants) {
      int numDimensions = simplexConstants.Length;
      Vector[] vertices = new Vector[numDimensions + 1];

      // define one point of the simplex as the given initial guesses
      Vector p0 = new Vector(numDimensions);
      for (int i = 0; i < numDimensions; i++) {
        p0[i] = simplexConstants[i].Value;
      }

      // now fill in the remaining vertices by stepping along each axis
      vertices[0] = p0;
      for (int i = 0; i < numDimensions; i++) {
        double scale = simplexConstants[i].InitialPerturbation;
        Vector unitVector = new Vector(numDimensions);
        unitVector[i] = 1;
        vertices[i + 1] = p0.Add(unitVector.Multiply(scale));
      }
      return vertices;
    }

    /// <summary>
    /// Evaluate the trial point centroid + scaleFactor * (vertices[HighestIndex] - centroid), where the
    /// centroid excludes the worst vertex, and replace the worst vertex with it if it is an improvement.
    /// A factor of -1 reflects, 0.5 contracts; a factor of 2 applied right after an accepted reflection
    /// yields the expansion point.
    /// </summary>
    /// <param name="scaleFactor">Scaling of the vector from the centroid to the worst vertex.</param>
    /// <param name="errorProfile">Current ranking of the vertices; not modified.</param>
    /// <param name="vertices">Simplex vertices; the worst one may be replaced.</param>
    /// <param name="errorValues">Error values of the vertices; updated together with the vertices.</param>
    /// <param name="objectiveFunction">The function to evaluate (called exactly once).</param>
    /// <returns>The error value at the trial point, whether or not it was accepted.</returns>
    private static double _tryToScaleSimplex(double scaleFactor, ErrorProfile errorProfile, Vector[] vertices,
                                             double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
      // find the centroid through which we will reflect
      Vector centroid = _computeCentroid(vertices, errorProfile);

      // define the vector from the centroid to the high point
      Vector centroidToHighPoint = vertices[errorProfile.HighestIndex].Subtract(centroid);

      // scale and position the vector to determine the new trial point
      Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);

      // evaluate the new point
      double newErrorValue = objectiveFunction(newPoint.Components);

      // if it's better, replace the old high point
      if (newErrorValue < errorValues[errorProfile.HighestIndex]) {
        vertices[errorProfile.HighestIndex] = newPoint;
        errorValues[errorProfile.HighestIndex] = newErrorValue;
      }

      return newErrorValue;
    }

    /// <summary>
    /// Contract the simplex uniformly towards the lowest point: every other vertex is moved halfway
    /// towards it and re-evaluated (vertices.Length - 1 objective evaluations).
    /// </summary>
    /// <param name="errorProfile">Ranking of the vertices; only LowestIndex is used.</param>
    /// <param name="vertices">Simplex vertices; updated in place.</param>
    /// <param name="errorValues">Error values of the vertices; updated in place.</param>
    /// <param name="objectiveFunction">The function to evaluate.</param>
    private static void _shrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues,
                                       ObjectiveFunctionDelegate objectiveFunction) {
      Vector lowestVertex = vertices[errorProfile.LowestIndex];
      for (int i = 0; i < vertices.Length; i++) {
        if (i != errorProfile.LowestIndex) {
          vertices[i] = (vertices[i].Add(lowestVertex)).Multiply(0.5);
          errorValues[i] = objectiveFunction(vertices[i].Components);
        }
      }
    }

    /// <summary>
    /// Compute the centroid of all vertices except the worst one.
    /// </summary>
    /// <param name="vertices">Simplex vertices (numDimensions + 1 of them).</param>
    /// <param name="errorProfile">Ranking of the vertices; HighestIndex is excluded.</param>
    /// <returns>The mean of the remaining vertices.</returns>
    private static Vector _computeCentroid(Vector[] vertices, ErrorProfile errorProfile) {
      int numVertices = vertices.Length;
      Vector centroid = new Vector(numVertices - 1); // dimension == numVertices - 1
      for (int i = 0; i < numVertices; i++) {
        if (i != errorProfile.HighestIndex) {
          centroid = centroid.Add(vertices[i]);
        }
      }
      return centroid.Multiply(1.0d / (numVertices - 1));
    }

    /// <summary>
    /// Indices of the best, worst and second-worst vertex of the simplex.
    /// </summary>
    private sealed class ErrorProfile {
      /// <summary>Index of the vertex with the largest error value (the worst vertex).</summary>
      public int HighestIndex { get; set; }

      /// <summary>Index of the vertex with the second-largest error value.</summary>
      public int NextHighestIndex { get; set; }

      /// <summary>Index of the vertex with the smallest error value (the best vertex).</summary>
      public int LowestIndex { get; set; }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.