[17848] | 1 | using System;
|
---|
| 2 |
|
---|
| 3 | namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
|
---|
/// <summary>
/// Why <c>NelderMeadSimplex.Regress</c> stopped iterating.
/// </summary>
public enum TerminationReason {
  /// <summary>The number of objective function evaluations reached the permitted maximum.</summary>
  MaxFunctionEvaluations = 0,

  /// <summary>The relative range of error values across the simplex fell below the convergence tolerance.</summary>
  Converged = 1,

  /// <summary>Initial value; no termination reason has been assigned.</summary>
  Unspecified = 2
}
|
---|
| 9 |
|
---|
/// <summary>
/// Objective function minimized by <c>NelderMeadSimplex.Regress</c>.
/// </summary>
/// <param name="constants">Components of the candidate point (one value per simplex dimension).</param>
/// <returns>The error value of the candidate; the optimizer treats lower values as better.</returns>
public delegate double ObjectiveFunctionDelegate(double[] constants);
|
---|
| 11 |
|
---|
| 12 |
|
---|
/// <summary>
/// Derivative-free minimizer based on the Nelder-Mead downhill simplex method.
/// The worst vertex is repeatedly reflected, expanded or contracted through the centroid
/// of the remaining vertices, or the whole simplex is shrunk towards the best vertex.
/// </summary>
public sealed class NelderMeadSimplex {
  // Small value added to the denominator of the relative convergence test so that it
  // stays finite when both the highest and the lowest error value are (close to) zero.
  private const double Jitter = 1e-10d;

  /// <summary>
  /// Minimizes <paramref name="objectiveFunction"/> starting from the simplex described by
  /// <paramref name="simplexConstants"/>.
  /// </summary>
  /// <param name="simplexConstants">One entry per dimension: <c>Value</c> is the initial guess and
  /// <c>InitialPerturbation</c> the step used to build the other simplex vertices.</param>
  /// <param name="convergenceTolerance">Iteration stops once the relative range between the highest and
  /// lowest error value in the simplex is below this value.</param>
  /// <param name="maxEvaluations">Budget of objective function evaluations (including the evaluations of
  /// the initial simplex). It is checked after each iteration, so the final count may exceed it slightly.</param>
  /// <param name="objectiveFunction">The function to minimize.</param>
  /// <returns>The termination reason, the components of the best vertex, its error value and the
  /// number of objective function evaluations performed.</returns>
  /// <exception cref="InvalidOperationException"><paramref name="objectiveFunction"/> or
  /// <paramref name="simplexConstants"/> is null.</exception>
  /// <exception cref="ArgumentException"><paramref name="simplexConstants"/> is empty.</exception>
  public static RegressionResult Regress(SimplexConstant[] simplexConstants, double convergenceTolerance, int maxEvaluations,
    ObjectiveFunctionDelegate objectiveFunction) {
    // confirm that we are in a position to commence
    if (objectiveFunction == null)
      throw new InvalidOperationException("ObjectiveFunction must be set to a valid ObjectiveFunctionDelegate");

    if (simplexConstants == null)
      throw new InvalidOperationException("SimplexConstants must be initialized");

    // ranking the simplex needs at least two vertices, i.e. at least one dimension
    if (simplexConstants.Length == 0)
      throw new ArgumentException("SimplexConstants must contain at least one constant", "simplexConstants");

    // create and evaluate the initial simplex
    int numVertices = simplexConstants.Length + 1;
    Vector[] vertices = _initializeVertices(simplexConstants);
    double[] errorValues = _initializeErrorValues(vertices, objectiveFunction);
    int evaluationCount = numVertices; // one evaluation per initial vertex

    TerminationReason terminationReason = TerminationReason.Unspecified;

    // iterate until we converge, or exhaust the permitted number of evaluations
    while (true) {
      ErrorProfile errorProfile = _evaluateSimplex(errorValues);

      // see if the range in point heights is small enough to exit
      if (_hasConverged(convergenceTolerance, errorProfile, errorValues)) {
        terminationReason = TerminationReason.Converged;
        break;
      }

      // attempt a reflection of the simplex
      double reflectionPointValue = _tryToScaleSimplex(-1.0, errorProfile, vertices, errorValues, objectiveFunction);
      evaluationCount++;
      if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) {
        // it's better than the best point, so attempt an expansion of the simplex;
        // the reflected point already occupies the high slot, so scaling by 2 expands past it
        _tryToScaleSimplex(2.0, errorProfile, vertices, errorValues, objectiveFunction);
        evaluationCount++;
      } else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) {
        // it is no better than the second worst point, so attempt a contraction to look
        // for an intermediate point
        double currentWorst = errorValues[errorProfile.HighestIndex];
        double contractionPointValue = _tryToScaleSimplex(0.5, errorProfile, vertices, errorValues, objectiveFunction);
        evaluationCount++;
        if (contractionPointValue >= currentWorst) {
          // that would be even worse, so contract uniformly towards the low point;
          // the error profile is rebuilt at the start of the next iteration
          evaluationCount += _shrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
        }
      }

      // check to see if we have exceeded our allotted number of evaluations
      if (evaluationCount >= maxEvaluations) {
        terminationReason = TerminationReason.MaxFunctionEvaluations;
        break;
      }
    }

    // Vertices may have been replaced after the last profile was computed (this happens when the
    // loop stops on the evaluation budget), so re-rank before picking the best vertex.
    ErrorProfile finalProfile = _evaluateSimplex(errorValues);
    return new RegressionResult(terminationReason,
      vertices[finalProfile.LowestIndex].Components, errorValues[finalProfile.LowestIndex], evaluationCount);
  }

  /// <summary>
  /// Evaluates the objective function at each vertex.
  /// </summary>
  /// <param name="vertices">The simplex vertices.</param>
  /// <param name="objectiveFunction">The function to evaluate.</param>
  /// <returns>An array whose i-th entry is the error value of <paramref name="vertices"/>[i].</returns>
  private static double[] _initializeErrorValues(Vector[] vertices, ObjectiveFunctionDelegate objectiveFunction) {
    double[] errorValues = new double[vertices.Length];
    for (int i = 0; i < vertices.Length; i++) {
      errorValues[i] = objectiveFunction(vertices[i].Components);
    }
    return errorValues;
  }

  /// <summary>
  /// Checks whether the highest and lowest error values are close enough, relative to their
  /// magnitude, to consider the search converged:
  /// 2 * |high - low| / (|high| + |low| + Jitter) &lt; tolerance.
  /// </summary>
  /// <param name="convergenceTolerance">The relative tolerance.</param>
  /// <param name="errorProfile">Indices of the highest and lowest vertices.</param>
  /// <param name="errorValues">Error value of each vertex.</param>
  /// <returns>True if the relative range is below <paramref name="convergenceTolerance"/>.</returns>
  private static bool _hasConverged(double convergenceTolerance, ErrorProfile errorProfile, double[] errorValues) {
    double highest = errorValues[errorProfile.HighestIndex];
    double lowest = errorValues[errorProfile.LowestIndex];
    double range = 2 * Math.Abs(highest - lowest) / (Math.Abs(highest) + Math.Abs(lowest) + Jitter);
    return range < convergenceTolerance;
  }

  /// <summary>
  /// Ranks the vertices by error value, determining the lowest, highest and second-highest index.
  /// Requires at least two error values.
  /// </summary>
  /// <param name="errorValues">Error value of each vertex.</param>
  /// <returns>The resulting <see cref="ErrorProfile"/>.</returns>
  private static ErrorProfile _evaluateSimplex(double[] errorValues) {
    ErrorProfile errorProfile = new ErrorProfile();

    // seed highest / next-highest from the first two vertices
    if (errorValues[0] > errorValues[1]) {
      errorProfile.HighestIndex = 0;
      errorProfile.NextHighestIndex = 1;
    } else {
      errorProfile.HighestIndex = 1;
      errorProfile.NextHighestIndex = 0;
    }
    errorProfile.LowestIndex = 0;

    for (int index = 0; index < errorValues.Length; index++) {
      double errorValue = errorValues[index];
      // strict comparison keeps the first minimum, so LowestIndex never coincides with
      // HighestIndex, even when all error values are equal
      if (errorValue < errorValues[errorProfile.LowestIndex]) {
        errorProfile.LowestIndex = index;
      }
      if (errorValue > errorValues[errorProfile.HighestIndex]) {
        errorProfile.NextHighestIndex = errorProfile.HighestIndex; // downgrade the current highest to next highest
        errorProfile.HighestIndex = index;
      } else if (errorValue > errorValues[errorProfile.NextHighestIndex] && index != errorProfile.HighestIndex) {
        errorProfile.NextHighestIndex = index;
      }
    }

    return errorProfile;
  }

  /// <summary>
  /// Constructs the initial simplex: vertex 0 is the vector of initial guesses, and vertex i+1 is
  /// vertex 0 moved by <c>InitialPerturbation</c> of constant i along dimension i.
  /// </summary>
  /// <param name="simplexConstants">Initial guesses and step sizes, one per dimension.</param>
  /// <returns>The <c>simplexConstants.Length + 1</c> vertices.</returns>
  private static Vector[] _initializeVertices(SimplexConstant[] simplexConstants) {
    int numDimensions = simplexConstants.Length;
    Vector[] vertices = new Vector[numDimensions + 1];

    // define one point of the simplex as the given initial guesses
    Vector p0 = new Vector(numDimensions);
    for (int i = 0; i < numDimensions; i++) {
      p0[i] = simplexConstants[i].Value;
    }

    // now fill in the vertices, creating the additional points as:
    // P(i) = P(0) + Scale(i) * UnitVector(i)
    vertices[0] = p0;
    for (int i = 0; i < numDimensions; i++) {
      double scale = simplexConstants[i].InitialPerturbation;
      Vector unitVector = new Vector(numDimensions);
      unitVector[i] = 1;
      vertices[i + 1] = p0.Add(unitVector.Multiply(scale));
    }
    return vertices;
  }

  /// <summary>
  /// Tries the point centroid + scaleFactor * (high - centroid), where the centroid excludes the
  /// highest vertex, and replaces the highest vertex with it if its error value is lower.
  /// </summary>
  /// <param name="scaleFactor">-1 reflects, 2 expands (after an accepted reflection), 0.5 contracts.</param>
  /// <param name="errorProfile">Current ranking of the vertices.</param>
  /// <param name="vertices">The simplex vertices; the highest one may be replaced.</param>
  /// <param name="errorValues">Error value of each vertex; updated alongside <paramref name="vertices"/>.</param>
  /// <param name="objectiveFunction">The function to evaluate (called exactly once).</param>
  /// <returns>The error value of the trial point, whether or not it was accepted.</returns>
  private static double _tryToScaleSimplex(double scaleFactor, ErrorProfile errorProfile, Vector[] vertices,
    double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
    int highIndex = errorProfile.HighestIndex;

    // find the centroid through which we will reflect
    Vector centroid = _computeCentroid(vertices, errorProfile);

    // scale the vector from the centroid to the high point and position it at the centroid
    Vector centroidToHighPoint = vertices[highIndex].Subtract(centroid);
    Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);

    double newErrorValue = objectiveFunction(newPoint.Components);

    // if it's better, replace the old high point
    if (newErrorValue < errorValues[highIndex]) {
      vertices[highIndex] = newPoint;
      errorValues[highIndex] = newErrorValue;
    }

    return newErrorValue;
  }

  /// <summary>
  /// Contracts the simplex uniformly: every vertex except the lowest is moved halfway towards the
  /// lowest vertex and re-evaluated.
  /// </summary>
  /// <param name="errorProfile">Current ranking of the vertices.</param>
  /// <param name="vertices">The simplex vertices; all but the lowest are replaced.</param>
  /// <param name="errorValues">Error value of each vertex; updated alongside <paramref name="vertices"/>.</param>
  /// <param name="objectiveFunction">The function to evaluate.</param>
  /// <returns>The number of objective function evaluations performed (one per moved vertex).</returns>
  private static int _shrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues,
    ObjectiveFunctionDelegate objectiveFunction) {
    Vector lowestVertex = vertices[errorProfile.LowestIndex];
    int evaluations = 0;
    for (int i = 0; i < vertices.Length; i++) {
      if (i == errorProfile.LowestIndex) continue;
      vertices[i] = (vertices[i].Add(lowestVertex)).Multiply(0.5);
      errorValues[i] = objectiveFunction(vertices[i].Components);
      evaluations++;
    }
    return evaluations;
  }

  /// <summary>
  /// Computes the centroid of all vertices except the highest (worst) one.
  /// </summary>
  /// <param name="vertices">The simplex vertices (dimension = vertices.Length - 1).</param>
  /// <param name="errorProfile">Identifies the vertex to exclude.</param>
  /// <returns>The mean of the remaining vertices.</returns>
  private static Vector _computeCentroid(Vector[] vertices, ErrorProfile errorProfile) {
    int numVertices = vertices.Length;
    Vector centroid = new Vector(numVertices - 1);
    for (int i = 0; i < numVertices; i++) {
      if (i != errorProfile.HighestIndex) {
        centroid = centroid.Add(vertices[i]);
      }
    }
    return centroid.Multiply(1.0d / (numVertices - 1));
  }

  /// <summary>
  /// Indices of the lowest, highest and second-highest error value in the simplex.
  /// </summary>
  private sealed class ErrorProfile {
    public int HighestIndex { get; set; }

    public int NextHighestIndex { get; set; }

    public int LowestIndex { get; set; }
  }
}
|
---|
| 257 | }
|
---|