source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs @ 14316

Last change on this file since 14316 was 14316, checked in by mkommend, 5 years ago

#2657: Changed nonlinear regression to perform at least one optimization.

File size: 9.2 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Optimization;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Nonlinear regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// The model is given as an infix expression (<see cref="ModelStructure"/>). Only the numeric constants
  /// of that expression are fit to the training data. The fit is repeated from several random starting
  /// points and the solution with the smallest training error is reported.
  /// </remarks>
  [Item("Nonlinear Regression (NLR)", "Nonlinear regression (curve fitting) data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableClass]
  public sealed class NonlinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RegressionSolutionResultName = "Regression solution";
    private const string ModelStructureParameterName = "Model structure";
    private const string IterationsParameterName = "Iterations";
    private const string RestartsParameterName = "Restarts";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string SeedParameterName = "Seed";

    // Descriptions shared by the default constructor and the deserialization hook,
    // so that both always create parameters with identical texts.
    private const string RestartsParameterDescription = "The number of independent random restarts";
    private const string SeedParameterDescription = "The PRNG seed value.";
    private const string SetSeedRandomlyParameterDescription = "Switch to determine if the random number seed should be initialized randomly.";

    #region parameter properties
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }

    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }

    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }

    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>The function (infix expression) whose numeric constants are fit.</summary>
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    /// <summary>The maximum number of iterations for constants optimization.</summary>
    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }

    /// <summary>The number of additional optimizations performed after the first one.</summary>
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }

    /// <summary>The PRNG seed value; overwritten at the start of a run if <see cref="SetSeedRandomly"/> is set.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }

    /// <summary>Determines whether a fresh random seed is drawn at the start of each run.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private NonlinearRegression(bool deserializing) : base(deserializing) { }
    private NonlinearRegression(NonlinearRegression original, Cloner cloner)
      : base(original, cloner) {
    }
    public NonlinearRegression()
      : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, RestartsParameterDescription, new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue()));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
    }

    /// <summary>
    /// Adds parameters that are missing from the deserialized parameter collection
    /// (restarts, seed, set-seed-randomly) with their default values.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      if (!Parameters.ContainsKey(RestartsParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, RestartsParameterDescription, new IntValue(1)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue()));
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NonlinearRegression(this, cloner);
    }

    #region nonlinear regression
    /// <summary>
    /// Performs one optimization plus <see cref="Restarts"/> further ones and stores the solution with the
    /// smallest training RMSE, together with its training and test RMSE, in the results.
    /// </summary>
    /// <remarks>
    /// Every optimization is handed the same random number generator, so the constants written in the
    /// model structure are always replaced by random starting values.
    /// </remarks>
    protected override void Run() {
      if (SetSeedRandomly) Seed = (new System.Random()).Next();
      var rand = new MersenneTwister((uint)Seed);

      // At least one optimization is always performed, independent of the number of restarts.
      var bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
      for (int r = 0; r < Restarts; r++) {
        var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
        if (IsBetter(solution.TrainingRootMeanSquaredError, bestSolution.TrainingRootMeanSquaredError)) {
          bestSolution = solution;
        }
      }

      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", bestSolution));
      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(bestSolution.TrainingRootMeanSquaredError)));
      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(bestSolution.TestRootMeanSquaredError)));
    }

    /// <summary>
    /// Decides whether a candidate training error should replace the best error found so far.
    /// </summary>
    /// <remarks>
    /// Every comparison with NaN is false, so a plain "less than" test would never replace a NaN best
    /// error. A NaN best error is therefore replaced by any candidate that is not NaN.
    /// </remarks>
    private static bool IsBetter(double candidateError, double bestError) {
      if (double.IsNaN(candidateError)) return false;
      return double.IsNaN(bestError) || candidateError < bestError;
    }

    /// <summary>
    /// Fits a model to the data by optimizing the numeric constants.
    /// Model is specified as infix expression containing variable names and numbers.
    /// The starting point for the numeric constants is initialized randomly if a random number generator is specified (~N(0,1)). Otherwise the user specified constants are
    /// used as a starting point.
    /// </summary>
    /// <param name="problemData">Training and test data</param>
    /// <param name="modelStructure">The function as infix expression</param>
    /// <param name="maxIterations">Number of constant optimization iterations (using Levenberg-Marquardt algorithm)</param>
    /// <param name="random">Optional random number generator for random initialization of numeric constants.</param>
    /// <returns>A regression solution wrapping the fitted and scaled model and a clone of <paramref name="problemData"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> or <paramref name="modelStructure"/> is null.</exception>
    /// <exception cref="ArgumentException">The constants of the parsed model structure cannot be optimized.</exception>
    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations, IRandom random = null) {
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (modelStructure == null) throw new ArgumentNullException("modelStructure");

      var parser = new InfixExpressionParser();
      var tree = parser.Parse(modelStructure);

      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");

      // initialize constants randomly
      if (random != null) {
        foreach (var node in tree.IterateNodesPrefix().OfType<ConstantTreeNode>()) {
          node.Value = NormalDistributedRandom.NextDouble(random, 0, 1);
        }
      }
      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

      // The optimized constants are written back into the tree (updateConstantsInTree).
      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
        applyLinearScaling: false, maxIterations: maxIterations,
        updateVariableWeights: false, updateConstantsInTree: true);

      var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
      scaledModel.Scale(problemData);
      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(scaledModel, (IRegressionProblemData)problemData.Clone());
      solution.Model.Name = "Regression Model";
      solution.Name = "Regression Solution";
      return solution;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.