using System;

namespace HeuristicLab.Algorithms.DataAnalysis.ContinuedFractionRegression {
  /// <summary>
  /// Reason why <see cref="NelderMeadSimplex.Regress"/> stopped iterating.
  /// </summary>
  public enum TerminationReason {
    /// <summary>The evaluation budget (<c>maxEvaluations</c>) was used up.</summary>
    MaxFunctionEvaluations,
    /// <summary>The relative spread between the highest and lowest vertex error fell below the tolerance.</summary>
    Converged,
    /// <summary>Default value; <see cref="NelderMeadSimplex.Regress"/> always reports one of the other reasons.</summary>
    Unspecified
  }

  /// <summary>
  /// Objective minimized by <see cref="NelderMeadSimplex"/>: receives the current constants and
  /// returns an error value (lower is better).
  /// </summary>
  public delegate double ObjectiveFunctionDelegate(double[] constants);

  /// <summary>
  /// Derivative-free minimization of an <see cref="ObjectiveFunctionDelegate"/> using the
  /// Nelder–Mead downhill simplex method (reflection, expansion, contraction and shrink steps).
  /// </summary>
  public sealed class NelderMeadSimplex {
    // Small value added to the denominator of the relative convergence test so it stays finite
    // when both the highest and the lowest error are zero.
    private const double Jitter = 1e-10d;

    // Factors applied to the vector from the centroid to the worst vertex (see TryToScaleSimplex):
    // -1 mirrors the worst vertex through the centroid, 2 extrapolates twice as far along that
    // vector, 0.5 moves halfway towards the centroid.
    private const double ReflectionFactor = -1.0;
    private const double ExpansionFactor = 2.0;
    private const double ContractionFactor = 0.5;

    /// <summary>
    /// Minimizes <paramref name="objectiveFunction"/> starting from the given constants.
    /// </summary>
    /// <param name="simplexConstants">
    /// Initial guesses (<c>Value</c>) and initial step sizes (<c>InitialPerturbation</c>) per dimension.
    /// </param>
    /// <param name="convergenceTolerance">
    /// Iteration stops once 2|high - low| / (|high| + |low| + 1e-10) falls below this value,
    /// where high/low are the worst and best vertex errors.
    /// </param>
    /// <param name="maxEvaluations">
    /// Budget of objective evaluations performed by simplex steps. The numDimensions + 1 evaluations
    /// used to build the initial simplex are not counted. The budget is checked after each step, so
    /// at least one step runs unless the initial simplex has already converged.
    /// </param>
    /// <param name="objectiveFunction">Function to minimize.</param>
    /// <returns>
    /// A <see cref="RegressionResult"/> built from the termination reason, the components of the
    /// lowest-error vertex, that vertex's error, and the number of counted evaluations.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="objectiveFunction"/> or <paramref name="simplexConstants"/> is null.
    /// </exception>
    public static RegressionResult Regress(SimplexConstant[] simplexConstants, double convergenceTolerance, int maxEvaluations, ObjectiveFunctionDelegate objectiveFunction) {
      if (objectiveFunction == null)
        throw new InvalidOperationException("ObjectiveFunction must be set to a valid ObjectiveFunctionDelegate");
      if (simplexConstants == null)
        throw new InvalidOperationException("SimplexConstants must be initialized");

      int numVertices = simplexConstants.Length + 1;
      Vector[] vertices = InitializeVertices(simplexConstants);
      double[] errorValues = InitializeErrorValues(vertices, objectiveFunction);
      int evaluationCount = 0;
      TerminationReason terminationReason;

      while (true) {
        ErrorProfile errorProfile = EvaluateSimplex(errorValues);

        if (HasConverged(convergenceTolerance, errorProfile, errorValues)) {
          terminationReason = TerminationReason.Converged;
          break;
        }

        // Reflect the worst vertex through the centroid of the remaining vertices.
        double reflectionPointValue = TryToScaleSimplex(ReflectionFactor, errorProfile, vertices, errorValues, objectiveFunction);
        evaluationCount++;

        if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) {
          // The reflected point is at least as good as the best vertex. It has replaced the worst
          // vertex (if it improved on it), so scaling by 2 pushes twice as far along the same line.
          TryToScaleSimplex(ExpansionFactor, errorProfile, vertices, errorValues, objectiveFunction);
          evaluationCount++;
        } else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) {
          // The reflected point is no better than the second-worst vertex: look for an
          // intermediate point halfway between the centroid and the (possibly updated) worst vertex.
          double currentWorst = errorValues[errorProfile.HighestIndex];
          double contractionPointValue = TryToScaleSimplex(ContractionFactor, errorProfile, vertices, errorValues, objectiveFunction);
          evaluationCount++;

          if (contractionPointValue >= currentWorst) {
            // Contraction did not help either: shrink every vertex towards the best one.
            // The lowest vertex is not re-evaluated, hence numVertices - 1 evaluations.
            ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
            evaluationCount += numVertices - 1;
          }
        }

        if (evaluationCount >= maxEvaluations) {
          terminationReason = TerminationReason.MaxFunctionEvaluations;
          break;
        }
      }

      // Re-rank after the loop. The last expansion/shrink may have produced a new best vertex that
      // the profile computed at the top of the final iteration does not know about. On the
      // Converged path nothing has changed since that ranking, so this yields the same index.
      int bestIndex = EvaluateSimplex(errorValues).LowestIndex;
      return new RegressionResult(terminationReason, vertices[bestIndex].Components, errorValues[bestIndex], evaluationCount);
    }

    /// <summary>
    /// Evaluates the objective function at each vertex.
    /// </summary>
    /// <param name="vertices">Simplex vertices.</param>
    /// <param name="objectiveFunction">Function to evaluate.</param>
    /// <returns>An array whose i-th entry is the error of <c>vertices[i]</c>.</returns>
    private static double[] InitializeErrorValues(Vector[] vertices, ObjectiveFunctionDelegate objectiveFunction) {
      double[] errorValues = new double[vertices.Length];
      for (int i = 0; i < vertices.Length; i++) {
        errorValues[i] = objectiveFunction(vertices[i].Components);
      }
      return errorValues;
    }

    /// <summary>
    /// Checks whether the highest and lowest vertex errors are relatively close enough to stop.
    /// </summary>
    /// <param name="convergenceTolerance">Threshold for the relative range.</param>
    /// <param name="errorProfile">Indices of the highest and lowest vertices.</param>
    /// <param name="errorValues">Error value per vertex.</param>
    /// <returns><c>true</c> when the relative range is strictly below the tolerance.</returns>
    private static bool HasConverged(double convergenceTolerance, ErrorProfile errorProfile, double[] errorValues) {
      double highest = errorValues[errorProfile.HighestIndex];
      double lowest = errorValues[errorProfile.LowestIndex];
      double range = 2 * Math.Abs(highest - lowest) / (Math.Abs(highest) + Math.Abs(lowest) + Jitter);
      return range < convergenceTolerance;
    }

    /// <summary>
    /// Ranks the vertices: finds the indices of the lowest, highest and second-highest error.
    /// </summary>
    /// <param name="errorValues">Error value per vertex; must contain at least one entry.</param>
    /// <returns>The resulting <see cref="ErrorProfile"/>.</returns>
    private static ErrorProfile EvaluateSimplex(double[] errorValues) {
      ErrorProfile errorProfile = new ErrorProfile();

      // Degenerate simplex (no constants): the single vertex is lowest, highest and next-highest,
      // which the default index of 0 already expresses. Without this guard errorValues[1] throws.
      if (errorValues.Length < 2) {
        return errorProfile;
      }

      if (errorValues[0] > errorValues[1]) {
        errorProfile.HighestIndex = 0;
        errorProfile.NextHighestIndex = 1;
      } else {
        errorProfile.HighestIndex = 1;
        errorProfile.NextHighestIndex = 0;
      }

      for (int index = 0; index < errorValues.Length; index++) {
        double errorValue = errorValues[index];
        if (errorValue <= errorValues[errorProfile.LowestIndex]) {
          errorProfile.LowestIndex = index;
        }
        if (errorValue > errorValues[errorProfile.HighestIndex]) {
          errorProfile.NextHighestIndex = errorProfile.HighestIndex; // previous highest becomes next-highest
          errorProfile.HighestIndex = index;
        } else if (errorValue > errorValues[errorProfile.NextHighestIndex] && index != errorProfile.HighestIndex) {
          errorProfile.NextHighestIndex = index;
        }
      }
      return errorProfile;
    }

    /// <summary>
    /// Builds the initial simplex: P(0) holds the initial guesses and
    /// P(i + 1) = P(0) + InitialPerturbation(i) * UnitVector(i).
    /// </summary>
    /// <param name="simplexConstants">Initial guesses and step sizes per dimension.</param>
    /// <returns>numDimensions + 1 vertices.</returns>
    private static Vector[] InitializeVertices(SimplexConstant[] simplexConstants) {
      int numDimensions = simplexConstants.Length;
      Vector[] vertices = new Vector[numDimensions + 1];

      Vector p0 = new Vector(numDimensions);
      for (int i = 0; i < numDimensions; i++) {
        p0[i] = simplexConstants[i].Value;
      }
      vertices[0] = p0;

      for (int i = 0; i < numDimensions; i++) {
        // NOTE(review): relies on new Vector(n) being zero-filled so that only component i is set.
        Vector unitVector = new Vector(numDimensions);
        unitVector[i] = 1;
        vertices[i + 1] = p0.Add(unitVector.Multiply(simplexConstants[i].InitialPerturbation));
      }
      return vertices;
    }

    /// <summary>
    /// Evaluates the trial point centroid + scaleFactor * (worst - centroid) and replaces the worst
    /// vertex with it if its error is strictly lower.
    /// </summary>
    /// <param name="scaleFactor">Scale applied to the centroid-to-worst-vertex vector.</param>
    /// <param name="errorProfile">Ranking identifying the worst vertex.</param>
    /// <param name="vertices">Simplex vertices; the worst entry may be replaced.</param>
    /// <param name="errorValues">Error per vertex; updated together with <paramref name="vertices"/>.</param>
    /// <param name="objectiveFunction">Function to evaluate.</param>
    /// <returns>The trial point's error, whether or not it was accepted.</returns>
    private static double TryToScaleSimplex(double scaleFactor, ErrorProfile errorProfile, Vector[] vertices, double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
      Vector centroid = ComputeCentroid(vertices, errorProfile);
      Vector centroidToHighPoint = vertices[errorProfile.HighestIndex].Subtract(centroid);
      Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);

      double newErrorValue = objectiveFunction(newPoint.Components);
      if (newErrorValue < errorValues[errorProfile.HighestIndex]) {
        vertices[errorProfile.HighestIndex] = newPoint;
        errorValues[errorProfile.HighestIndex] = newErrorValue;
      }
      return newErrorValue;
    }

    /// <summary>
    /// Moves every vertex except the lowest halfway towards the lowest vertex and re-evaluates it.
    /// </summary>
    /// <param name="errorProfile">Ranking identifying the lowest vertex.</param>
    /// <param name="vertices">Simplex vertices; all except the lowest are replaced.</param>
    /// <param name="errorValues">Error per vertex; updated for every moved vertex.</param>
    /// <param name="objectiveFunction">Function to evaluate (called vertices.Length - 1 times).</param>
    private static void ShrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues, ObjectiveFunctionDelegate objectiveFunction) {
      Vector lowestVertex = vertices[errorProfile.LowestIndex];
      for (int i = 0; i < vertices.Length; i++) {
        if (i != errorProfile.LowestIndex) {
          vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5);
          errorValues[i] = objectiveFunction(vertices[i].Components);
        }
      }
    }

    /// <summary>
    /// Computes the centroid of all vertices except the worst one.
    /// </summary>
    /// <param name="vertices">Simplex vertices (numDimensions + 1 of them).</param>
    /// <param name="errorProfile">Ranking identifying the worst vertex.</param>
    /// <returns>The mean of the remaining vertices.</returns>
    private static Vector ComputeCentroid(Vector[] vertices, ErrorProfile errorProfile) {
      int numVertices = vertices.Length;
      // Vector length equals the dimensionality, i.e. numVertices - 1.
      Vector centroid = new Vector(numVertices - 1);
      for (int i = 0; i < numVertices; i++) {
        if (i != errorProfile.HighestIndex) {
          centroid = centroid.Add(vertices[i]);
        }
      }
      return centroid.Multiply(1.0d / (numVertices - 1));
    }

    /// <summary>
    /// Indices of the highest (worst), second-highest and lowest (best) error values in the simplex.
    /// </summary>
    private sealed class ErrorProfile {
      public int HighestIndex { get; set; }
      public int NextHighestIndex { get; set; }
      public int LowestIndex { get; set; }
    }
  }
}