1 | using System;
|
---|
2 |
|
---|
3 | namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
|
---|
/// <summary>
/// Describes why a simplex regression run stopped iterating.
/// </summary>
public enum TerminationReason {
  /// <summary>
  /// The number of objective function evaluations reached the permitted maximum.
  /// </summary>
  MaxFunctionEvaluations = 0,

  /// <summary>
  /// The relative range of the error values over the simplex fell below the convergence tolerance.
  /// </summary>
  Converged = 1,

  /// <summary>
  /// No termination reason has been determined (initial value before the search loop exits).
  /// </summary>
  Unspecified = 2
}
|
---|
9 |
|
---|
/// <summary>
/// Signature of the objective function that the simplex search minimizes.
/// </summary>
/// <param name="constants">The candidate constant values (one entry per dimension) to evaluate.</param>
/// <returns>The error value for the supplied constants; the search treats smaller values as better.</returns>
public delegate double ObjectiveFunctionDelegate(double[] constants);
|
---|
11 |
|
---|
12 |
|
---|
/// <summary>
/// Minimizes an objective function with the Nelder-Mead downhill simplex method: a simplex of
/// (number of constants + 1) vertices is repeatedly reflected, expanded, contracted or shrunk
/// until its error values are close enough together or the evaluation budget is used up.
/// </summary>
public sealed class NelderMeadSimplex {
  // a small value used to protect against floating point noise; it keeps the denominator
  // of the relative error range in HasConverged away from zero
  private const double Jitter = 1e-10d;

  /// <summary>
  /// Searches for the constants that minimize <paramref name="objectiveFunction"/>.
  /// </summary>
  /// <param name="simplexConstants">
  /// One entry per dimension, supplying the starting value and the initial perturbation
  /// used to build the initial simplex. Must contain at least one entry.
  /// </param>
  /// <param name="convergenceTolerance">
  /// The search stops as converged once the relative range between the highest and the
  /// lowest error value of the simplex is smaller than this value.
  /// </param>
  /// <param name="maxEvaluations">
  /// Budget of objective function evaluations. The budget is checked once per iteration, so
  /// the final count can exceed it by the evaluations of the last iteration.
  /// </param>
  /// <param name="objectiveFunction">The function to minimize.</param>
  /// <returns>
  /// A result holding the termination reason, the constants and error value of the
  /// lowest-error vertex, and the number of objective function evaluations performed
  /// (including those needed to score the initial simplex).
  /// </returns>
  /// <exception cref="InvalidOperationException">
  /// <paramref name="objectiveFunction"/> or <paramref name="simplexConstants"/> is null, or
  /// <paramref name="simplexConstants"/> is empty.
  /// </exception>
  public static RegressionResult Regress(SimplexConstant[] simplexConstants, double convergenceTolerance, int maxEvaluations,
                                         ObjectiveFunctionDelegate objectiveFunction) {
    // confirm that we are in a position to commence
    if (objectiveFunction == null) {
      throw new InvalidOperationException("ObjectiveFunction must be set to a valid ObjectiveFunctionDelegate");
    }

    if (simplexConstants == null) {
      throw new InvalidOperationException("SimplexConstants must be initialized");
    }

    // a simplex needs at least two vertices; without this guard the failure would surface
    // as an IndexOutOfRangeException deep inside EvaluateSimplex
    if (simplexConstants.Length == 0) {
      throw new InvalidOperationException("SimplexConstants must contain at least one constant");
    }

    // create the initial simplex and score every vertex
    Vector[] vertices = InitializeVertices(simplexConstants);
    double[] errorValues = InitializeErrorValues(vertices, objectiveFunction);

    // scoring the initial simplex cost one evaluation per vertex
    int evaluationCount = vertices.Length;
    TerminationReason terminationReason = TerminationReason.Unspecified;
    ErrorProfile errorProfile;

    // iterate until we converge, or complete our permitted number of evaluations
    while (true) {
      errorProfile = EvaluateSimplex(errorValues);

      // see if the range in point heights is small enough to exit
      if (HasConverged(convergenceTolerance, errorProfile, errorValues)) {
        terminationReason = TerminationReason.Converged;
        break;
      }

      // attempt a reflection of the simplex
      double reflectionPointValue = TryToScaleSimplex(-1.0, errorProfile, vertices, errorValues, objectiveFunction);
      ++evaluationCount;
      if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) {
        // it's better than the best point, so attempt an expansion of the simplex
        TryToScaleSimplex(2.0, errorProfile, vertices, errorValues, objectiveFunction);
        ++evaluationCount;
      } else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) {
        // it would be no better than the second worst point, so attempt a contraction to look
        // for an intermediate point
        double currentWorst = errorValues[errorProfile.HighestIndex];
        double contractionPointValue = TryToScaleSimplex(0.5, errorProfile, vertices, errorValues, objectiveFunction);
        ++evaluationCount;
        if (contractionPointValue >= currentWorst) {
          // that would be even worse, so let's try to contract uniformly towards the low point;
          // the shrink reports how many evaluations it really performed (the low point is skipped)
          evaluationCount += ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
        }
      }

      // check to see if we have exceeded our allotted number of evaluations
      if (evaluationCount >= maxEvaluations) {
        // the steps above may have produced a new best vertex; refresh the profile so that
        // the result reports the current lowest point rather than the one from the start
        // of this iteration
        errorProfile = EvaluateSimplex(errorValues);
        terminationReason = TerminationReason.MaxFunctionEvaluations;
        break;
      }
    }

    return new RegressionResult(terminationReason,
      vertices[errorProfile.LowestIndex].Components, errorValues[errorProfile.LowestIndex], evaluationCount);
  }

  /// <summary>
  /// Evaluate the objective function at each vertex to create a corresponding
  /// list of error values for each vertex
  /// </summary>
  /// <param name="vertices">The vertices of the simplex.</param>
  /// <param name="objectiveFunction">The function to evaluate at every vertex.</param>
  /// <returns>An array whose i-th entry is the objective value of the i-th vertex.</returns>
  private static double[] InitializeErrorValues(Vector[] vertices, ObjectiveFunctionDelegate objectiveFunction) {
    double[] errorValues = new double[vertices.Length];
    for (int i = 0; i < vertices.Length; i++) {
      errorValues[i] = objectiveFunction(vertices[i].Components);
    }
    return errorValues;
  }

  /// <summary>
  /// Check whether the points in the error profile have so little range that we
  /// consider ourselves to have converged
  /// </summary>
  /// <param name="convergenceTolerance">Upper bound (exclusive) on the relative range.</param>
  /// <param name="errorProfile">Indices of the highest and lowest error values.</param>
  /// <param name="errorValues">The error value of every vertex.</param>
  /// <returns>
  /// true if 2 * |high - low| / (|high| + |low| + Jitter) is smaller than the tolerance.
  /// </returns>
  private static bool HasConverged(double convergenceTolerance, ErrorProfile errorProfile, double[] errorValues) {
    double highest = errorValues[errorProfile.HighestIndex];
    double lowest = errorValues[errorProfile.LowestIndex];
    double range = 2 * Math.Abs(highest - lowest) / (Math.Abs(highest) + Math.Abs(lowest) + Jitter);
    return range < convergenceTolerance;
  }

  /// <summary>
  /// Examine all error values to determine the ErrorProfile
  /// </summary>
  /// <param name="errorValues">The error value of every vertex; must hold at least two entries.</param>
  /// <returns>The indices of the highest, second highest and lowest error value.</returns>
  private static ErrorProfile EvaluateSimplex(double[] errorValues) {
    // seed the two "worst" slots from the first two vertices
    int highest;
    int nextHighest;
    if (errorValues[0] > errorValues[1]) {
      highest = 0;
      nextHighest = 1;
    } else {
      highest = 1;
      nextHighest = 0;
    }
    int lowest = 0;

    for (int index = 0; index < errorValues.Length; index++) {
      double errorValue = errorValues[index];
      // "<=" means that among equal lowest values the last index wins
      if (errorValue <= errorValues[lowest]) {
        lowest = index;
      }
      if (errorValue > errorValues[highest]) {
        nextHighest = highest; // downgrade the current highest to next highest
        highest = index;
      } else if (errorValue > errorValues[nextHighest] && index != highest) {
        nextHighest = index;
      }
    }

    return new ErrorProfile {
      HighestIndex = highest,
      NextHighestIndex = nextHighest,
      LowestIndex = lowest
    };
  }

  /// <summary>
  /// Construct an initial simplex, given starting guesses for the constants, and
  /// initial step sizes for each dimension
  /// </summary>
  /// <param name="simplexConstants">Starting value and initial perturbation per dimension.</param>
  /// <returns>
  /// The starting point followed by one vertex per dimension, each offset from the starting
  /// point by that dimension's initial perturbation.
  /// </returns>
  private static Vector[] InitializeVertices(SimplexConstant[] simplexConstants) {
    int numDimensions = simplexConstants.Length;
    Vector[] vertices = new Vector[numDimensions + 1];

    // define one point of the simplex as the given initial guesses
    Vector p0 = new Vector(numDimensions);
    for (int i = 0; i < numDimensions; i++) {
      p0[i] = simplexConstants[i].Value;
    }

    // now fill in the vertices, creating the additional points as:
    // P(i) = P(0) + Scale(i) * UnitVector(i)
    vertices[0] = p0;
    for (int i = 0; i < numDimensions; i++) {
      double scale = simplexConstants[i].InitialPerturbation;
      Vector unitVector = new Vector(numDimensions);
      unitVector[i] = 1;
      vertices[i + 1] = p0.Add(unitVector.Multiply(scale));
    }
    return vertices;
  }

  /// <summary>
  /// Test a scaling operation of the high point, and replace it if it is an improvement
  /// </summary>
  /// <param name="scaleFactor">
  /// Factor applied to the vector from the centroid to the high point; the caller uses
  /// -1 for a reflection, 2 for an expansion and 0.5 for a contraction.
  /// </param>
  /// <param name="errorProfile">Supplies the index of the high point.</param>
  /// <param name="vertices">The simplex; the high point is overwritten if the trial point is better.</param>
  /// <param name="errorValues">The error values; updated together with <paramref name="vertices"/>.</param>
  /// <param name="objectiveFunction">The function used to evaluate the trial point.</param>
  /// <returns>The error value of the trial point, whether or not it was accepted.</returns>
  private static double TryToScaleSimplex(double scaleFactor, ErrorProfile errorProfile, Vector[] vertices,
                                          double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
    int highestIndex = errorProfile.HighestIndex;

    // find the centroid through which we will reflect
    Vector centroid = ComputeCentroid(vertices, errorProfile);

    // define the vector from the centroid to the high point
    Vector centroidToHighPoint = vertices[highestIndex].Subtract(centroid);

    // scale and position the vector to determine the new trial point
    Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);

    // evaluate the new point
    double newErrorValue = objectiveFunction(newPoint.Components);

    // if it's better, replace the old high point
    if (newErrorValue < errorValues[highestIndex]) {
      vertices[highestIndex] = newPoint;
      errorValues[highestIndex] = newErrorValue;
    }

    return newErrorValue;
  }

  /// <summary>
  /// Contract the simplex uniformly around the lowest point
  /// </summary>
  /// <param name="errorProfile">Supplies the index of the lowest point, which is left untouched.</param>
  /// <param name="vertices">The simplex; every other vertex is moved halfway towards the lowest point.</param>
  /// <param name="errorValues">The error values; re-evaluated for every moved vertex.</param>
  /// <param name="objectiveFunction">The function used to re-evaluate the moved vertices.</param>
  /// <returns>The number of objective function evaluations performed.</returns>
  private static int ShrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues,
                                   ObjectiveFunctionDelegate objectiveFunction) {
    int lowestIndex = errorProfile.LowestIndex;
    Vector lowestVertex = vertices[lowestIndex];
    int evaluations = 0;
    for (int i = 0; i < vertices.Length; i++) {
      if (i == lowestIndex) {
        continue;
      }
      vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5);
      errorValues[i] = objectiveFunction(vertices[i].Components);
      evaluations++;
    }
    return evaluations;
  }

  /// <summary>
  /// Compute the centroid of all points except the worst
  /// </summary>
  /// <param name="vertices">The vertices of the simplex.</param>
  /// <param name="errorProfile">Supplies the index of the worst (highest error) vertex to leave out.</param>
  /// <returns>The mean of all vertices other than the worst one.</returns>
  private static Vector ComputeCentroid(Vector[] vertices, ErrorProfile errorProfile) {
    // a simplex has one more vertex than it has dimensions
    int numOtherVertices = vertices.Length - 1;
    Vector sum = new Vector(numOtherVertices);
    for (int i = 0; i < vertices.Length; i++) {
      if (i != errorProfile.HighestIndex) {
        sum = sum.Add(vertices[i]);
      }
    }
    return sum.Multiply(1.0d / numOtherVertices);
  }

  /// <summary>
  /// Indices into the error value array of the highest, second highest and lowest error.
  /// All three indices default to 0.
  /// </summary>
  private sealed class ErrorProfile {
    public int HighestIndex { get; set; }

    public int NextHighestIndex { get; set; }

    public int LowestIndex { get; set; }
  }
}
|
---|
257 | }
|
---|