source: branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs @ 14251

Last change on this file since 14251 was 14251, checked in by gkronber, 5 years ago

#2650:

  • extended non-linear regression to work with factors
  • fixed bugs in constants optimizer and tree interpreter
  • improved simplification of factor variables
  • added support for factors to ERC view
  • added support for factors to solution comparison view
  • activated view for all factors
File size: 6.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Parameters;
29using HeuristicLab.Optimization;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
34
35namespace HeuristicLab.Algorithms.DataAnalysis {
36  /// <summary>
37  /// Nonlinear regression data analysis algorithm.
38  /// </summary>
39  [Item("Nonlinear Regression (NLR)", "Nonlinear regression (curve fitting) data analysis algorithm (wrapper for ALGLIB).")]
40  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
41  [StorableClass]
42  public sealed class NonlinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
43    private const string RegressionSolutionResultName = "Regression solution";
44    private const string ModelStructureParameterName = "Model structure";
45    private const string IterationsParameterName = "Iterations";
46
47    public IFixedValueParameter<StringValue> ModelStructureParameter {
48      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
49    }
50    public IFixedValueParameter<IntValue> IterationsParameter {
51      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
52    }
53
54    public string ModelStructure {
55      get { return ModelStructureParameter.Value.Value; }
56      set { ModelStructureParameter.Value.Value = value; }
57    }
58
59    public int Iterations {
60      get { return IterationsParameter.Value.Value; }
61      set { IterationsParameter.Value.Value = value; }
62    }
63
64
65    [StorableConstructor]
66    private NonlinearRegression(bool deserializing) : base(deserializing) { }
67    private NonlinearRegression(NonlinearRegression original, Cloner cloner)
68      : base(original, cloner) {
69    }
70    public NonlinearRegression()
71      : base() {
72      Problem = new RegressionProblem();
73      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
74      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
75    }
76    [StorableHook(HookType.AfterDeserialization)]
77    private void AfterDeserialization() { }
78
79    public override IDeepCloneable Clone(Cloner cloner) {
80      return new NonlinearRegression(this, cloner);
81    }
82
83    #region nonlinear regression
84    protected override void Run() {
85      var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
86      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", solution));
87      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(solution.TrainingRootMeanSquaredError)));
88      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(solution.TestRootMeanSquaredError)));
89    }
90
91    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations) {
92      var parser = new InfixExpressionParser();
93      var tree = parser.Parse(modelStructure);
94      // parser handles double and string variables equally by creating a VariableTreeNode
95      // post-process to replace VariableTreeNodes by FactorVariableTreeNodes for all string variables
96      var factorSymbol = new FactorVariable();
97      factorSymbol.VariableNames =
98        problemData.AllowedInputVariables.Where(name => problemData.Dataset.VariableHasType<string>(name));
99      factorSymbol.AllVariableNames = factorSymbol.VariableNames;
100      factorSymbol.VariableValues =
101        factorSymbol.VariableNames.Select(name => new KeyValuePair<string, List<string>>(name, problemData.Dataset.GetReadOnlyStringValues(name).Distinct().ToList()));
102
103      foreach (var parent in tree.IterateNodesPrefix().ToArray()) {
104        for (int i = 0; i < parent.SubtreeCount; i++) {
105          var child = parent.GetSubtree(i) as VariableTreeNode;
106          if (child != null && factorSymbol.VariableNames.Contains(child.VariableName)) {
107            parent.RemoveSubtree(i);
108            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
109            factorTreeNode.VariableName = child.VariableName;
110            factorTreeNode.Weights =
111              factorTreeNode.Symbol.GetVariableValues(factorTreeNode.VariableName).Select(_ => 1.0).ToArray(); // weight = 1.0 for each value
112            parent.InsertSubtree(i, factorTreeNode);
113          }
114        }
115      }
116
117      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");
118
119      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
120      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
121        applyLinearScaling: false, maxIterations: maxIterations,
122        updateVariableWeights: false, updateConstantsInTree: true);
123
124
125      var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
126      scaledModel.Scale(problemData);
127      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(scaledModel, (IRegressionProblemData)problemData.Clone());
128      solution.Model.Name = "Regression Model";
129      solution.Name = "Regression Solution";
130      return solution;
131    }
132    #endregion
133  }
134}
Note: See TracBrowser for help on using the repository browser.