Free cookie consent management tool by TermsFeed Policy Generator

Changeset 14258


Ignore:
Timestamp:
08/17/16 12:19:24 (8 years ago)
Author:
gkronber
Message:

#2657: added random restarts for NonlinearRegression (curve fitting) algorithm

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs

    r14109 r14258  
    3333using HeuristicLab.Problems.DataAnalysis.Symbolic;
    3434using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
     35using HeuristicLab.Random;
    3536
    3637namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4546    private const string ModelStructureParameterName = "Model structure";
    4647    private const string IterationsParameterName = "Iterations";
     48    private const string RestartsParameterName = "Restarts";
     49    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     50    private const string SeedParameterName = "Seed";
    4751
    4852    public IFixedValueParameter<StringValue> ModelStructureParameter {
     
    5155    public IFixedValueParameter<IntValue> IterationsParameter {
    5256      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
     57    }
     58
     59    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
     60      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62
     63    public IFixedValueParameter<IntValue> SeedParameter {
     64      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     65    }
     66
     67    public IFixedValueParameter<IntValue> RestartsParameter {
     68      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    5369    }
    5470
     
    6379    }
    6480
     81    public int Restarts {
     82      get { return RestartsParameter.Value.Value; }
     83      set { RestartsParameter.Value.Value = value; }
     84    }
     85
     86    public int Seed {
     87      get { return SeedParameter.Value.Value; }
     88      set { SeedParameter.Value.Value = value; }
     89    }
     90
     91    public bool SetSeedRandomly {
     92      get { return SetSeedRandomlyParameter.Value.Value; }
     93      set { SetSeedRandomlyParameter.Value.Value = value; }
     94    }
    6595
    6696    [StorableConstructor]
     
    74104      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
    75105      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
     106      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts", new IntValue(10)));
     107      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
     108      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
    76109    }
     110
    77111    [StorableHook(HookType.AfterDeserialization)]
    78     private void AfterDeserialization() { }
     112    private void AfterDeserialization() {
     113      // BackwardsCompatibility3.3
     114      #region Backwards compatible code, remove with 3.4
     115      if (!Parameters.ContainsKey(RestartsParameterName))
     116        Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts", new IntValue(1)));
     117      if (!Parameters.ContainsKey(SeedParameterName))
     118        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
     119      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
     120        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
     121      #endregion
     122    }
    79123
    80124    public override IDeepCloneable Clone(Cloner cloner) {
     
    84128    #region nonlinear regression
    85129    protected override void Run() {
    86       var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
    87       Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", solution));
    88       Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(solution.TrainingRootMeanSquaredError)));
    89       Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(solution.TestRootMeanSquaredError)));
     130      if (SetSeedRandomly) Seed = (new System.Random()).Next();
     131      var rand = new MersenneTwister((uint)Seed);
     132      IRegressionSolution bestSolution = null;
     133      for (int r = 0; r < Restarts; r++) {
     134        var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     135        if (bestSolution == null || solution.TrainingRootMeanSquaredError < bestSolution.TrainingRootMeanSquaredError) {
     136          bestSolution = solution;
     137        }
     138      }
     139
     140      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", bestSolution));
     141      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(bestSolution.TrainingRootMeanSquaredError)));
     142      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(bestSolution.TestRootMeanSquaredError)));
     143
    90144    }
    91145
    92     public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations) {
     146    /// <summary>
     147    /// Fits a model to the data by optimizing the numeric constants.
     148    /// Model is specified as infix expression containing variable names and numbers.
     149    /// The starting point for the numeric constants is initialized randomly if a random number generator is specified (~N(0,1)). Otherwise the user specified constants are
     150    /// used as a starting point.
     151    /// </summary>
     152    /// <param name="problemData">Training and test data</param>
     153    /// <param name="modelStructure">The function as infix expression</param>
     154    /// <param name="maxIterations">Number of constant optimization iterations (using Levenberg-Marquardt algorithm)</param>
     155    /// <param name="random">Optional random number generator for random initialization of numeric constants.</param>
     156    /// <returns></returns>
     157    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations, IRandom random = null) {
    93158      var parser = new InfixExpressionParser();
    94159      var tree = parser.Parse(modelStructure);
    95       var simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    96      
     160
    97161      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");
    98162
     163      // initialize constants randomly
     164      if (random != null) {
     165        foreach (var node in tree.IterateNodesPrefix().OfType<ConstantTreeNode>()) {
     166          node.Value = NormalDistributedRandom.NextDouble(random, 0, 1);
     167        }
     168      }
    99169      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    100       SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
     170
     171      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
    101172        applyLinearScaling: false, maxIterations: maxIterations,
    102173        updateVariableWeights: false, updateConstantsInTree: true);
    103 
    104174
    105175      var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
Note: See TracChangeset for help on using the changeset viewer.