[17848] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Text;
|
---|
| 5 | using System.Threading.Tasks;
|
---|
| 6 |
|
---|
| 7 | namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
|
---|
/// <summary>
/// Describes why a simplex search stopped.
/// </summary>
public enum TerminationReason {
  /// <summary>The evaluation counter reached the permitted maximum number of evaluations.</summary>
  MaxFunctionEvaluations = 0,

  /// <summary>The spread of error values across the simplex fell below the convergence tolerance.</summary>
  Converged = 1,

  /// <summary>Placeholder used before a termination reason has been determined.</summary>
  Unspecified = 2
}
|
---|
| 13 |
|
---|
/// <summary>
/// Signature of the function minimized by the simplex search: maps one candidate set of
/// constants to the error value of that candidate.
/// </summary>
/// <param name="constants">Candidate values, one entry per dimension of the search space.</param>
/// <returns>The error value for <paramref name="constants"/>; the search keeps the lowest value it finds.</returns>
public delegate double ObjectiveFunctionDelegate(double[] constants);
|
---|
| 15 |
|
---|
| 16 |
|
---|
/// <summary>
/// Minimizes an objective function with the Nelder-Mead downhill simplex method
/// (reflection, expansion, contraction and shrink steps on a simplex of n + 1 vertices).
/// </summary>
public sealed class NelderMeadSimplex {
  // a small value added to the convergence denominator to protect against floating point noise
  // (it also keeps the denominator non-zero when both extreme error values are zero)
  private const double Jitter = 1e-10d;

  /// <summary>
  /// Searches for the constants that minimize <paramref name="objectiveFunction"/>, starting from
  /// the simplex spanned by the initial values and perturbations in <paramref name="simplexConstants"/>.
  /// </summary>
  /// <param name="simplexConstants">Initial value and initial perturbation of each constant; one entry per dimension.</param>
  /// <param name="convergenceTolerance">
  /// The search stops as soon as the relative spread between the highest and the lowest error value
  /// of the simplex drops below this value.
  /// </param>
  /// <param name="maxEvaluations">
  /// Budget of objective function evaluations, including the evaluations of the initial simplex.
  /// The budget is checked at the end of every iteration, so the final count can exceed it by the
  /// cost of the last iteration.
  /// </param>
  /// <param name="objectiveFunction">The function to minimize.</param>
  /// <returns>
  /// The termination reason, the components and error value of the best vertex of the final simplex,
  /// and the number of objective function evaluations that were performed.
  /// </returns>
  /// <exception cref="InvalidOperationException">
  /// <paramref name="objectiveFunction"/> or <paramref name="simplexConstants"/> is null, or
  /// <paramref name="simplexConstants"/> is empty.
  /// </exception>
  public static RegressionResult Regress(SimplexConstant[] simplexConstants, double convergenceTolerance, int maxEvaluations,
                                         ObjectiveFunctionDelegate objectiveFunction) {
    // confirm that we are in a position to commence
    if (objectiveFunction == null)
      throw new InvalidOperationException("ObjectiveFunction must be set to a valid ObjectiveFunctionDelegate");

    if (simplexConstants == null)
      throw new InvalidOperationException("SimplexConstants must be initialized");

    // a simplex needs at least two vertices; without this guard the profile evaluation
    // would fail later with an IndexOutOfRangeException
    if (simplexConstants.Length == 0)
      throw new InvalidOperationException("SimplexConstants must contain at least one constant");

    // create the initial simplex
    int numVertices = simplexConstants.Length + 1;
    Vector[] vertices = _initializeVertices(simplexConstants);
    double[] errorValues = _initializeErrorValues(vertices, objectiveFunction);

    // building the initial simplex already cost one evaluation per vertex
    int evaluationCount = numVertices;
    TerminationReason terminationReason = TerminationReason.Unspecified;
    ErrorProfile errorProfile;

    // iterate until we converge, or use up our permitted number of evaluations
    while (true) {
      errorProfile = _evaluateSimplex(errorValues);

      // see if the range in point heights is small enough to exit
      if (_hasConverged(convergenceTolerance, errorProfile, errorValues)) {
        terminationReason = TerminationReason.Converged;
        break;
      }

      // attempt a reflection of the simplex
      double reflectionPointValue = _tryToScaleSimplex(-1.0, errorProfile, vertices, errorValues, objectiveFunction);
      ++evaluationCount;
      if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) {
        // it's at least as good as the best point, so attempt an expansion of the simplex
        _tryToScaleSimplex(2.0, errorProfile, vertices, errorValues, objectiveFunction);
        ++evaluationCount;
      } else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) {
        // it is no better than the second worst point, so attempt a contraction to look
        // for an intermediate point
        double currentWorst = errorValues[errorProfile.HighestIndex];
        double contractionPointValue = _tryToScaleSimplex(0.5, errorProfile, vertices, errorValues, objectiveFunction);
        ++evaluationCount;
        if (contractionPointValue >= currentWorst) {
          // that would be even worse, so let's try to contract uniformly towards the low point
          _shrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
          // every vertex except the lowest one was re-evaluated
          evaluationCount += numVertices - 1;
        }
      }

      // check to see if we have exceeded our allotted number of evaluations
      if (evaluationCount >= maxEvaluations) {
        terminationReason = TerminationReason.MaxFunctionEvaluations;
        // the steps above may have changed which vertex is the best one; refresh the profile
        // so that the result below reports the best vertex actually found
        errorProfile = _evaluateSimplex(errorValues);
        break;
      }
    }

    return new RegressionResult(terminationReason,
      vertices[errorProfile.LowestIndex].Components, errorValues[errorProfile.LowestIndex], evaluationCount);
  }

  /// <summary>
  /// Evaluate the objective function at each vertex to create a corresponding
  /// array of error values, one per vertex and in the same order
  /// </summary>
  /// <param name="vertices">The vertices of the simplex.</param>
  /// <param name="objectiveFunction">The function to evaluate at each vertex.</param>
  /// <returns>A new array whose i-th entry is the objective value of the i-th vertex.</returns>
  private static double[] _initializeErrorValues(Vector[] vertices, ObjectiveFunctionDelegate objectiveFunction) {
    double[] errorValues = new double[vertices.Length];
    for (int i = 0; i < vertices.Length; i++) {
      errorValues[i] = objectiveFunction(vertices[i].Components);
    }
    return errorValues;
  }

  /// <summary>
  /// Check whether the points in the error profile have so little range that we
  /// consider ourselves to have converged
  /// </summary>
  /// <param name="convergenceTolerance">Upper bound (exclusive) for the relative range.</param>
  /// <param name="errorProfile">Indices of the highest and lowest error values.</param>
  /// <param name="errorValues">The error value of each vertex.</param>
  /// <returns>
  /// True if 2 * |high - low| / (|high| + |low| + jitter) is smaller than
  /// <paramref name="convergenceTolerance"/>.
  /// </returns>
  private static bool _hasConverged(double convergenceTolerance, ErrorProfile errorProfile, double[] errorValues) {
    double highest = errorValues[errorProfile.HighestIndex];
    double lowest = errorValues[errorProfile.LowestIndex];
    double range = 2 * Math.Abs(highest - lowest) /
                   (Math.Abs(highest) + Math.Abs(lowest) + Jitter);
    return range < convergenceTolerance;
  }

  /// <summary>
  /// Examine all error values to determine the ErrorProfile, i.e. the indices of the
  /// highest, the second highest and the lowest error value
  /// </summary>
  /// <param name="errorValues">The error value of each vertex; must contain at least two entries.</param>
  /// <returns>
  /// The profile of <paramref name="errorValues"/>. Among equal lowest values the last index wins;
  /// among equal highest values the first one found wins.
  /// </returns>
  private static ErrorProfile _evaluateSimplex(double[] errorValues) {
    ErrorProfile errorProfile = new ErrorProfile();

    // seed highest / next highest from the first two entries; LowestIndex starts out as 0
    if (errorValues[0] > errorValues[1]) {
      errorProfile.HighestIndex = 0;
      errorProfile.NextHighestIndex = 1;
    } else {
      errorProfile.HighestIndex = 1;
      errorProfile.NextHighestIndex = 0;
    }

    for (int index = 0; index < errorValues.Length; index++) {
      double errorValue = errorValues[index];
      if (errorValue <= errorValues[errorProfile.LowestIndex]) {
        errorProfile.LowestIndex = index;
      }
      if (errorValue > errorValues[errorProfile.HighestIndex]) {
        errorProfile.NextHighestIndex = errorProfile.HighestIndex; // downgrade the current highest to next highest
        errorProfile.HighestIndex = index;
      } else if (errorValue > errorValues[errorProfile.NextHighestIndex] && index != errorProfile.HighestIndex) {
        errorProfile.NextHighestIndex = index;
      }
    }

    return errorProfile;
  }

  /// <summary>
  /// Construct an initial simplex, given starting guesses for the constants, and
  /// initial step sizes for each dimension
  /// </summary>
  /// <param name="simplexConstants">Starting value and initial perturbation for each dimension.</param>
  /// <returns>
  /// The n + 1 vertices of the simplex: the starting point itself, followed by one vertex per
  /// dimension that is offset from the starting point along that dimension.
  /// </returns>
  private static Vector[] _initializeVertices(SimplexConstant[] simplexConstants) {
    int numDimensions = simplexConstants.Length;
    Vector[] vertices = new Vector[numDimensions + 1];

    // define one point of the simplex as the given initial guesses
    Vector p0 = new Vector(numDimensions);
    for (int i = 0; i < numDimensions; i++) {
      p0[i] = simplexConstants[i].Value;
    }

    // now fill in the vertices, creating the additional points as:
    // P(i) = P(0) + Scale(i) * UnitVector(i)
    vertices[0] = p0;
    for (int i = 0; i < numDimensions; i++) {
      double scale = simplexConstants[i].InitialPerturbation;
      Vector unitVector = new Vector(numDimensions);
      unitVector[i] = 1;
      vertices[i + 1] = p0.Add(unitVector.Multiply(scale));
    }
    return vertices;
  }

  /// <summary>
  /// Test a scaling operation of the high point, and replace it if it is an improvement
  /// </summary>
  /// <param name="scaleFactor">
  /// Factor applied to the vector from the centroid to the high point; the caller uses -1 for a
  /// reflection, 2 for an expansion and 0.5 for a contraction.
  /// </param>
  /// <param name="errorProfile">Profile that identifies the high point; it is not modified.</param>
  /// <param name="vertices">The simplex; the high point is overwritten if the trial point is better.</param>
  /// <param name="errorValues">The error values; the entry of the high point is overwritten together with its vertex.</param>
  /// <param name="objectiveFunction">The function to evaluate at the trial point.</param>
  /// <returns>The error value of the trial point, whether or not it was accepted.</returns>
  private static double _tryToScaleSimplex(double scaleFactor, ErrorProfile errorProfile, Vector[] vertices,
                                           double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
    int highestIndex = errorProfile.HighestIndex;

    // find the centroid through which we will reflect
    Vector centroid = _computeCentroid(vertices, errorProfile);

    // define the vector from the centroid to the high point
    Vector centroidToHighPoint = vertices[highestIndex].Subtract(centroid);

    // scale and position the vector to determine the new trial point
    Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);

    // evaluate the new point
    double newErrorValue = objectiveFunction(newPoint.Components);

    // if it's better, replace the old high point
    if (newErrorValue < errorValues[highestIndex]) {
      vertices[highestIndex] = newPoint;
      errorValues[highestIndex] = newErrorValue;
    }

    return newErrorValue;
  }

  /// <summary>
  /// Contract the simplex uniformly around the lowest point: every other vertex is moved to the
  /// midpoint between itself and the lowest vertex and is then re-evaluated
  /// </summary>
  /// <param name="errorProfile">Profile that identifies the lowest vertex.</param>
  /// <param name="vertices">The simplex; all vertices except the lowest one are replaced.</param>
  /// <param name="errorValues">The error values; updated for every replaced vertex.</param>
  /// <param name="objectiveFunction">The function to evaluate at the moved vertices.</param>
  private static void _shrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues,
                                     ObjectiveFunctionDelegate objectiveFunction) {
    int lowestIndex = errorProfile.LowestIndex;
    Vector lowestVertex = vertices[lowestIndex];
    for (int i = 0; i < vertices.Length; i++) {
      if (i == lowestIndex) {
        continue; // the lowest vertex stays where it is and keeps its error value
      }
      vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5);
      errorValues[i] = objectiveFunction(vertices[i].Components);
    }
  }

  /// <summary>
  /// Compute the centroid of all points except the worst
  /// </summary>
  /// <param name="vertices">The vertices of the simplex.</param>
  /// <param name="errorProfile">Profile that identifies the worst (highest) vertex, which is excluded.</param>
  /// <returns>The mean of all vertices other than the highest one.</returns>
  private static Vector _computeCentroid(Vector[] vertices, ErrorProfile errorProfile) {
    // a simplex has one more vertex than it has dimensions, so this is both the
    // dimension of the centroid and the number of vertices that are averaged
    int numOtherVertices = vertices.Length - 1;

    Vector centroid = new Vector(numOtherVertices);
    for (int i = 0; i < vertices.Length; i++) {
      if (i != errorProfile.HighestIndex) {
        centroid = centroid.Add(vertices[i]);
      }
    }
    return centroid.Multiply(1.0d / numOtherVertices);
  }

  /// <summary>
  /// Indices (into the vertex and error value arrays) of the vertices with the highest,
  /// the second highest and the lowest error value. All indices default to 0.
  /// </summary>
  private sealed class ErrorProfile {
    public int HighestIndex { get; set; }

    public int NextHighestIndex { get; set; }

    public int LowestIndex { get; set; }
  }
}
|
---|
| 261 | }
|
---|