source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs @ 14024

Last change on this file since 14024 was 14024, checked in by gkronber, 6 years ago

#2627: added first implementation of nonlinear regression algorithm + formatter and parser for infix expressions

File size: 5.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Parameters;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Nonlinear regression data analysis algorithm.
  /// The user supplies a model structure as an infix expression; only the numeric constants
  /// of that expression are fitted, using the training partition of the problem data.
  /// </summary>
  [Item("Nonlinear Regression (NLR)", "Nonlinear regression (curve fitting) data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableClass]
  public sealed class NonlinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // These strings are used as lookup keys in Parameters / Results; keep the values stable.
    private const string RegressionSolutionResultName = "Regression solution";
    private const string ModelStructureParameterName = "Model structure";
    private const string IterationsParameterName = "Iterations";

    /// <summary>Parameter holding the infix expression whose numeric constants are fitted.</summary>
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    /// <summary>Parameter holding the maximum number of iterations for constants optimization.</summary>
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }

    /// <summary>Gets or sets the model structure (infix expression) to fit.</summary>
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    /// <summary>Gets or sets the maximum number of iterations for constants optimization.</summary>
    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }

    [StorableConstructor]
    private NonlinearRegression(bool deserializing) : base(deserializing) { }
    private NonlinearRegression(NonlinearRegression original, Cloner cloner)
      : base(original, cloner) {
      // This class declares no fields of its own; its settings live in the Parameters collection.
    }

    /// <summary>
    /// Creates the algorithm with an empty regression problem and default parameter values
    /// (model structure "1.0 * x*x + 0.0", 200 iterations).
    /// </summary>
    public NonlinearRegression()
      : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Intentionally empty for now. Backwards-compatibility fixes for older stored
      // instances (e.g. adding parameters introduced later) would go here.
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NonlinearRegression(this, cloner);
    }

    #region nonlinear regression
    /// <summary>
    /// Fits the configured model structure to the problem data and publishes the resulting
    /// solution together with its training and test root mean squared error.
    /// </summary>
    protected override void Run() {
      var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", solution));
      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(solution.TrainingRootMeanSquaredError)));
      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(solution.TestRootMeanSquaredError)));
    }

    /// <summary>
    /// Parses <paramref name="modelStructure"/>, optimizes its numeric constants on the training
    /// indices of <paramref name="problemData"/> and wraps the result in a symbolic regression solution.
    /// </summary>
    /// <param name="problemData">The data to fit; constants are optimized on its training indices.</param>
    /// <param name="modelStructure">Infix expression describing the model; only its numeric constants are tuned.</param>
    /// <param name="maxIterations">Maximum number of iterations for the constants optimizer; must not be negative.</param>
    /// <returns>A solution (named "Regression Solution") whose model has been scaled to the problem data.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> or <paramref name="modelStructure"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="modelStructure"/> is empty/whitespace, or contains constructs the constants optimizer does not support.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxIterations"/> is negative.</exception>
    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations) {
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (modelStructure == null) throw new ArgumentNullException("modelStructure");
      if (string.IsNullOrWhiteSpace(modelStructure)) throw new ArgumentException("The model structure must not be empty.", "modelStructure");
      if (maxIterations < 0) throw new ArgumentOutOfRangeException("maxIterations", "The maximum number of iterations must not be negative.");

      var parser = new InfixExpressionParser();
      var tree = parser.Parse(modelStructure);

      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");

      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      // updateConstantsInTree: true writes the fitted constants back into 'tree',
      // so the model built below reflects the optimization result.
      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
        applyLinearScaling: false, maxIterations: maxIterations,
        updateVariableWeights: false, updateConstantsInTree: true);

      // NOTE(review): constants are optimized without linear scaling, yet the model is scaled
      // afterwards, which may add offset/scale terms to the user's structure — confirm intended.
      var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
      scaledModel.Scale(problemData);
      // The solution is given its own clone of the problem data.
      var solution = new SymbolicRegressionSolution(scaledModel, (IRegressionProblemData)problemData.Clone());
      solution.Model.Name = "Regression Model";
      solution.Name = "Regression Solution";
      return solution;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.