source: branches/3106_AnalyticContinuedFractionsRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/ContinuedFractionRegression/NelderMeadSimplex.cs @ 17848

Last change on this file since 17848 was 17848, checked in by gkronber, 5 months ago

#3106 initial import of code (translated from HL script)

File size: 10.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6
namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
  /// <summary>
  /// Reason why <see cref="NelderMeadSimplex.Regress"/> stopped iterating.
  /// </summary>
  public enum TerminationReason {
    /// <summary>The budget of objective function evaluations was used up.</summary>
    MaxFunctionEvaluations,
    /// <summary>The relative spread between the worst and best error in the simplex fell below the tolerance.</summary>
    Converged,
    /// <summary>No termination reason has been determined.</summary>
    Unspecified
  }

  /// <summary>
  /// Objective function minimized by <see cref="NelderMeadSimplex"/>.
  /// </summary>
  /// <param name="constants">The point (one value per simplex constant) at which the objective is evaluated.</param>
  /// <returns>The error value at <paramref name="constants"/>; lower values are better.</returns>
  public delegate double ObjectiveFunctionDelegate(double[] constants);


  /// <summary>
  /// Derivative-free minimizer based on the Nelder-Mead downhill simplex method
  /// (reflection, expansion, contraction and shrink steps).
  /// </summary>
  public sealed class NelderMeadSimplex {
    // a small value used to protect against floating point noise (and division by zero) in the convergence test
    private const double Jitter = 1e-10d;

    /// <summary>
    /// Minimizes <paramref name="objectiveFunction"/> with the Nelder-Mead simplex method, starting from the
    /// values and initial perturbations given in <paramref name="simplexConstants"/>.
    /// </summary>
    /// <param name="simplexConstants">Starting value and initial step size for every dimension; must contain at least one element.</param>
    /// <param name="convergenceTolerance">
    /// Iteration stops once 2 * |f_high - f_low| / (|f_high| + |f_low| + 1e-10) is smaller than this value,
    /// where f_high / f_low are the worst / best error values in the simplex.
    /// </param>
    /// <param name="maxEvaluations">
    /// Budget of objective function evaluations, including the evaluation of every vertex of the initial simplex.
    /// The budget is checked between steps, so it can be exceeded by the cost of a single step.
    /// </param>
    /// <param name="objectiveFunction">The function to minimize. A NaN result is treated as +Infinity (worst possible error).</param>
    /// <returns>The best vertex found, its error value, the number of objective evaluations and the termination reason.</returns>
    /// <exception cref="InvalidOperationException"><paramref name="objectiveFunction"/> or <paramref name="simplexConstants"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="simplexConstants"/> is empty.</exception>
    public static RegressionResult Regress(SimplexConstant[] simplexConstants, double convergenceTolerance, int maxEvaluations,
                                           ObjectiveFunctionDelegate objectiveFunction) {
      // confirm that we are in a position to commence
      if (objectiveFunction == null)
        throw new InvalidOperationException("ObjectiveFunction must be set to a valid ObjectiveFunctionDelegate");

      if (simplexConstants == null)
        throw new InvalidOperationException("SimplexConstants must be initialized");

      // the ranking of the simplex needs at least two vertices, i.e. at least one dimension
      if (simplexConstants.Length == 0)
        throw new ArgumentException("At least one simplex constant is required.", "simplexConstants");

      // create and evaluate the initial simplex
      int numVertices = simplexConstants.Length + 1;
      Vector[] vertices = InitializeVertices(simplexConstants);
      double[] errorValues = InitializeErrorValues(vertices, objectiveFunction);
      int evaluationCount = numVertices; // the initial simplex costs one evaluation per vertex

      TerminationReason terminationReason;
      ErrorProfile errorProfile;

      // iterate until we converge, or use up our permitted number of evaluations
      while (true) {
        // rank the simplex before every termination test so that the reported best vertex
        // always reflects the outcome of the most recent step
        errorProfile = EvaluateSimplex(errorValues);

        // see if the range in point heights is small enough to exit
        if (HasConverged(convergenceTolerance, errorProfile, errorValues)) {
          terminationReason = TerminationReason.Converged;
          break;
        }

        // check to see if we have exhausted our allotted number of evaluations
        if (evaluationCount >= maxEvaluations) {
          terminationReason = TerminationReason.MaxFunctionEvaluations;
          break;
        }

        // attempt a reflection of the worst vertex through the centroid of the others
        double reflectionPointValue = TryToScaleSimplex(-1.0, errorProfile, vertices, errorValues, objectiveFunction);
        evaluationCount++;

        if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) {
          // it's better than the best point, so attempt an expansion further in the same direction
          TryToScaleSimplex(2.0, errorProfile, vertices, errorValues, objectiveFunction);
          evaluationCount++;
        } else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) {
          // it would be worse than the second-worst point, so attempt a contraction to look
          // for an intermediate point
          double currentWorst = errorValues[errorProfile.HighestIndex];
          double contractionPointValue = TryToScaleSimplex(0.5, errorProfile, vertices, errorValues, objectiveFunction);
          evaluationCount++;

          if (contractionPointValue >= currentWorst) {
            // that would be even worse, so contract uniformly towards the low point;
            // the error profile is recomputed at the start of the next iteration
            ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
            evaluationCount += numVertices - 1; // every vertex except the lowest one was re-evaluated
          }
        }
      }

      return new RegressionResult(terminationReason,
                                  vertices[errorProfile.LowestIndex].Components,
                                  errorValues[errorProfile.LowestIndex],
                                  evaluationCount);
    }

    /// <summary>
    /// Evaluates the objective at <paramref name="point"/>. NaN is mapped to +Infinity so that an undefined
    /// objective value ranks as the worst possible error instead of breaking the comparisons used to
    /// order the vertices (no value compares less than or equal to NaN).
    /// </summary>
    private static double Evaluate(ObjectiveFunctionDelegate objectiveFunction, Vector point) {
      double value = objectiveFunction(point.Components);
      return double.IsNaN(value) ? double.PositiveInfinity : value;
    }

    /// <summary>
    /// Evaluates the objective function at each vertex, producing one error value per vertex
    /// (same order as <paramref name="vertices"/>).
    /// </summary>
    private static double[] InitializeErrorValues(Vector[] vertices, ObjectiveFunctionDelegate objectiveFunction) {
      double[] errorValues = new double[vertices.Length];
      for (int i = 0; i < vertices.Length; i++) {
        errorValues[i] = Evaluate(objectiveFunction, vertices[i]);
      }
      return errorValues;
    }

    /// <summary>
    /// Returns true when the relative difference between the worst and the best error value
    /// is below <paramref name="convergenceTolerance"/>.
    /// </summary>
    private static bool HasConverged(double convergenceTolerance, ErrorProfile errorProfile, double[] errorValues) {
      double highest = errorValues[errorProfile.HighestIndex];
      double lowest = errorValues[errorProfile.LowestIndex];
      double range = 2 * Math.Abs(highest - lowest) / (Math.Abs(highest) + Math.Abs(lowest) + Jitter);
      return range < convergenceTolerance;
    }

    /// <summary>
    /// Determines the indices of the lowest, highest and second-highest error values.
    /// Requires at least two error values.
    /// </summary>
    private static ErrorProfile EvaluateSimplex(double[] errorValues) {
      ErrorProfile errorProfile = new ErrorProfile();
      if (errorValues[0] > errorValues[1]) {
        errorProfile.HighestIndex = 0;
        errorProfile.NextHighestIndex = 1;
      } else {
        errorProfile.HighestIndex = 1;
        errorProfile.NextHighestIndex = 0;
      }
      errorProfile.LowestIndex = 0;

      for (int index = 0; index < errorValues.Length; index++) {
        double errorValue = errorValues[index];
        // '<=' makes the last of several equally low vertices the lowest one
        if (errorValue <= errorValues[errorProfile.LowestIndex]) {
          errorProfile.LowestIndex = index;
        }
        if (errorValue > errorValues[errorProfile.HighestIndex]) {
          errorProfile.NextHighestIndex = errorProfile.HighestIndex; // downgrade the current highest to next highest
          errorProfile.HighestIndex = index;
        } else if (errorValue > errorValues[errorProfile.NextHighestIndex] && index != errorProfile.HighestIndex) {
          errorProfile.NextHighestIndex = index;
        }
      }

      return errorProfile;
    }

    /// <summary>
    /// Constructs the initial simplex: vertex 0 is the vector of initial values, and vertex i + 1 is
    /// vertex 0 moved by the initial perturbation of constant i along dimension i.
    /// </summary>
    private static Vector[] InitializeVertices(SimplexConstant[] simplexConstants) {
      int numDimensions = simplexConstants.Length;
      Vector[] vertices = new Vector[numDimensions + 1];

      // define one point of the simplex as the given initial guesses
      Vector p0 = new Vector(numDimensions);
      for (int i = 0; i < numDimensions; i++) {
        p0[i] = simplexConstants[i].Value;
      }

      // P(i) = P(0) + Scale(i) * UnitVector(i)
      vertices[0] = p0;
      for (int i = 0; i < numDimensions; i++) {
        Vector unitVector = new Vector(numDimensions);
        unitVector[i] = 1;
        vertices[i + 1] = p0.Add(unitVector.Multiply(simplexConstants[i].InitialPerturbation));
      }
      return vertices;
    }

    /// <summary>
    /// Evaluates the point centroid + scaleFactor * (highPoint - centroid) and replaces the highest vertex
    /// with it if its error is lower. The centroid is taken over all vertices except the highest one.
    /// </summary>
    /// <returns>The error value of the trial point (whether or not it was accepted).</returns>
    private static double TryToScaleSimplex(double scaleFactor, ErrorProfile errorProfile, Vector[] vertices,
                                            double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
      // find the centroid through which we will reflect
      Vector centroid = ComputeCentroid(vertices, errorProfile);

      // scale the vector from the centroid to the high point and position it at the centroid
      Vector centroidToHighPoint = vertices[errorProfile.HighestIndex].Subtract(centroid);
      Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);

      double newErrorValue = Evaluate(objectiveFunction, newPoint);

      // if it's better, replace the old high point
      if (newErrorValue < errorValues[errorProfile.HighestIndex]) {
        vertices[errorProfile.HighestIndex] = newPoint;
        errorValues[errorProfile.HighestIndex] = newErrorValue;
      }

      return newErrorValue;
    }

    /// <summary>
    /// Moves every vertex except the lowest one halfway towards the lowest vertex and re-evaluates it
    /// (one objective evaluation per moved vertex).
    /// </summary>
    private static void ShrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues,
                                      ObjectiveFunctionDelegate objectiveFunction) {
      Vector lowestVertex = vertices[errorProfile.LowestIndex];
      for (int i = 0; i < vertices.Length; i++) {
        if (i == errorProfile.LowestIndex) continue;
        vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5);
        errorValues[i] = Evaluate(objectiveFunction, vertices[i]);
      }
    }

    /// <summary>
    /// Computes the centroid of all vertices except the highest (worst) one.
    /// </summary>
    private static Vector ComputeCentroid(Vector[] vertices, ErrorProfile errorProfile) {
      int numVertices = vertices.Length;
      Vector centroid = new Vector(numVertices - 1); // dimensionality = number of vertices - 1
      for (int i = 0; i < numVertices; i++) {
        if (i != errorProfile.HighestIndex) {
          centroid = centroid.Add(vertices[i]);
        }
      }
      return centroid.Multiply(1.0d / (numVertices - 1));
    }

    /// <summary>
    /// Indices of the highest (worst), second-highest and lowest (best) error values of the simplex.
    /// </summary>
    private sealed class ErrorProfile {
      public int HighestIndex { get; set; }
      public int NextHighestIndex { get; set; }
      public int LowestIndex { get; set; }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.