Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/14/17 19:08:39 (8 years ago)
Author:
gkronber
Message:

#2657,#2677 merged r14258, r14316, r14319 and 14347.

Location:
stable
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs

    r14116 r14564  
    2121
    2222using System;
    23 using System.Collections.Generic;
    2423using System.Linq;
     24using HeuristicLab.Analysis;
    2525using HeuristicLab.Common;
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
     28using HeuristicLab.Optimization;
    2829using HeuristicLab.Parameters;
    29 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    30 using HeuristicLab.Optimization;
    3130using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3231using HeuristicLab.Problems.DataAnalysis;
    3332using HeuristicLab.Problems.DataAnalysis.Symbolic;
    3433using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
     34using HeuristicLab.Random;
    3535
    3636namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4545    private const string ModelStructureParameterName = "Model structure";
    4646    private const string IterationsParameterName = "Iterations";
     47    private const string RestartsParameterName = "Restarts";
     48    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     49    private const string SeedParameterName = "Seed";
     50    private const string InitParamsRandomlyParameterName = "InitializeParametersRandomly";
    4751
    4852    public IFixedValueParameter<StringValue> ModelStructureParameter {
     
    5155    public IFixedValueParameter<IntValue> IterationsParameter {
    5256      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
     57    }
     58
     59    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
     60      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62
     63    public IFixedValueParameter<IntValue> SeedParameter {
     64      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     65    }
     66
     67    public IFixedValueParameter<IntValue> RestartsParameter {
     68      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
     69    }
     70
     71    public IFixedValueParameter<BoolValue> InitParametersRandomlyParameter {
     72      get { return (IFixedValueParameter<BoolValue>)Parameters[InitParamsRandomlyParameterName]; }
    5373    }
    5474
     
    6383    }
    6484
     85    public int Restarts {
     86      get { return RestartsParameter.Value.Value; }
     87      set { RestartsParameter.Value.Value = value; }
     88    }
     89
     90    public int Seed {
     91      get { return SeedParameter.Value.Value; }
     92      set { SeedParameter.Value.Value = value; }
     93    }
     94
     95    public bool SetSeedRandomly {
     96      get { return SetSeedRandomlyParameter.Value.Value; }
     97      set { SetSeedRandomlyParameter.Value.Value = value; }
     98    }
     99
     100    public bool InitializeParametersRandomly {
     101      get { return InitParametersRandomlyParameter.Value.Value; }
     102      set { InitParametersRandomlyParameter.Value.Value = value; }
     103    }
    65104
    66105    [StorableConstructor]
     
    74113      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
    75114      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
    76     }
     115      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts (>0)", new IntValue(10)));
     116      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
     117      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
     118      Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the real-valued model parameters should be initialized randomly in each restart.", new BoolValue(false)));
     119
     120      SetParameterHiddenState();
     121
     122      InitParametersRandomlyParameter.Value.ValueChanged += (sender, args) => {
     123        SetParameterHiddenState();
     124      };
     125    }
     126
     127    private void SetParameterHiddenState() {
     128      var hide = !InitializeParametersRandomly;
     129      RestartsParameter.Hidden = hide;
     130      SeedParameter.Hidden = hide;
     131      SetSeedRandomlyParameter.Hidden = hide;
     132    }
     133
    77134    [StorableHook(HookType.AfterDeserialization)]
    78     private void AfterDeserialization() { }
     135    private void AfterDeserialization() {
     136      // BackwardsCompatibility3.3
     137      #region Backwards compatible code, remove with 3.4
     138      if (!Parameters.ContainsKey(RestartsParameterName))
     139        Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts", new IntValue(1)));
     140      if (!Parameters.ContainsKey(SeedParameterName))
     141        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
     142      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
     143        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
     144      if (!Parameters.ContainsKey(InitParamsRandomlyParameterName))
     145        Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the numeric parameters of the model should be initialized randomly.", new BoolValue(false)));
     146
     147      SetParameterHiddenState();
     148      InitParametersRandomlyParameter.Value.ValueChanged += (sender, args) => {
     149        SetParameterHiddenState();
     150      };
     151      #endregion
     152    }
    79153
    80154    public override IDeepCloneable Clone(Cloner cloner) {
     
    84158    #region nonlinear regression
    85159    protected override void Run() {
    86       var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
    87       Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", solution));
    88       Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(solution.TrainingRootMeanSquaredError)));
    89       Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(solution.TestRootMeanSquaredError)));
    90     }
    91 
    92     public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations) {
     160      IRegressionSolution bestSolution = null;
     161      if (InitializeParametersRandomly) {
     162        var qualityTable = new DataTable("RMSE table");
     163        qualityTable.VisualProperties.YAxisLogScale = true;
     164        var trainRMSERow = new DataRow("RMSE (train)");
     165        trainRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     166        var testRMSERow = new DataRow("RMSE test");
     167        testRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     168
     169        qualityTable.Rows.Add(trainRMSERow);
     170        qualityTable.Rows.Add(testRMSERow);
     171        Results.Add(new Result(qualityTable.Name, qualityTable.Name + " for all restarts", qualityTable));
     172        if (SetSeedRandomly) Seed = (new System.Random()).Next();
     173        var rand = new MersenneTwister((uint)Seed);
     174        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     175        trainRMSERow.Values.Add(bestSolution.TrainingRootMeanSquaredError);
     176        testRMSERow.Values.Add(bestSolution.TestRootMeanSquaredError);
     177        for (int r = 0; r < Restarts; r++) {
     178          var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     179          trainRMSERow.Values.Add(solution.TrainingRootMeanSquaredError);
     180          testRMSERow.Values.Add(solution.TestRootMeanSquaredError);
     181          if (solution.TrainingRootMeanSquaredError < bestSolution.TrainingRootMeanSquaredError) {
     182            bestSolution = solution;
     183          }
     184        }
     185      } else {
     186        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
     187      }
     188
     189      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", bestSolution));
     190      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(bestSolution.TrainingRootMeanSquaredError)));
     191      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(bestSolution.TestRootMeanSquaredError)));
     192
     193    }
     194
     195    /// <summary>
     196    /// Fits a model to the data by optimizing the numeric constants.
     197    /// Model is specified as infix expression containing variable names and numbers.
     198    /// The starting point for the numeric constants is initialized randomly if a random number generator is specified (~N(0,1)). Otherwise the user specified constants are
     199    /// used as a starting point.
     200    /// </summary>-
     201    /// <param name="problemData">Training and test data</param>
     202    /// <param name="modelStructure">The function as infix expression</param>
     203    /// <param name="maxIterations">Number of constant optimization iterations (using Levenberg-Marquardt algorithm)</param>
     204    /// <param name="random">Optional random number generator for random initialization of numeric constants.</param>
     205    /// <returns></returns>
     206    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations, IRandom rand = null) {
    93207      var parser = new InfixExpressionParser();
    94208      var tree = parser.Parse(modelStructure);
    95       var simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    96      
     209
    97210      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");
    98211
     212      // initialize constants randomly
     213      if (rand != null) {
     214        foreach (var node in tree.IterateNodesPrefix().OfType<ConstantTreeNode>()) {
     215          double f = Math.Exp(NormalDistributedRandom.NextDouble(rand, 0, 1));
     216          double s = rand.NextDouble() < 0.5 ? -1 : 1;
     217          node.Value = s * node.Value * f;
     218        }
     219      }
    99220      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    100       SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
     221
     222      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
    101223        applyLinearScaling: false, maxIterations: maxIterations,
    102224        updateVariableWeights: false, updateConstantsInTree: true);
    103 
    104225
    105226      var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
Note: See TracChangeset for help on using the changeset viewer.