Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2789_MathNetNumerics-Exploration/HeuristicLab.Algorithms.DataAnalysis.Experimental/LinearRegressionCombinations.cs @ 16793

Last change on this file since 16793 was 14998, checked in by gkronber, 8 years ago

#2789 added forward selection algorithm and algorithm to calculate all LR combinations

File size: 11.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Concurrent;
24using System.Collections.Generic;
25using System.Linq;
26using System.Threading;
27using System.Threading.Tasks;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
38
namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
  /// <summary>
  /// Fits an ordinary linear regression model for every combination of the checked input variables
  /// of the regression problem, using combinations of 1 up to <see cref="MaximumInputs"/> inputs.
  /// For each combination size, at most <see cref="MaximumSolutions"/> runs with the lowest test RMSE
  /// are kept and published in the "Runs" result.
  /// </summary>
  [Item("Linear Regression Combinations (LR)", "Calculates all possible LR solutions.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 102)]
  [StorableClass]
  public sealed class LinearRegressionCombinations : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // The problem whose ProblemDataChanged event is currently subscribed. It is remembered so the handler
    // can be detached when the problem is replaced. It is deliberately not [Storable]: the subscription is
    // re-established by RegisterEventHandlers after cloning and deserialization.
    private IRegressionProblem subscribedProblem;

    /// <summary>Parameter holding the largest number of input variables combined in a single model.</summary>
    public IFixedValueParameter<IntValue> MaximumInputsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters["Maximum Inputs"]; }
    }
    /// <summary>The largest number of input variables combined in a single model.</summary>
    public int MaximumInputs {
      get { return MaximumInputsParameter.Value.Value; }
      set { MaximumInputsParameter.Value.Value = value; }
    }

    /// <summary>Parameter controlling whether a symbolic solution is built for every model.</summary>
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters["Create Solution"]; }
    }
    /// <summary>When true, each stored run also contains the fitted solution under "Solution".</summary>
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }

    /// <summary>Parameter holding the number of best runs kept per combination size.</summary>
    public IFixedValueParameter<IntValue> MaximumSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters["Maximum Solutions stored"]; }
    }
    /// <summary>The number of runs (ranked by test RMSE) kept per combination size.</summary>
    public int MaximumSolutions {
      get { return MaximumSolutionsParameter.Value.Value; }
      set { MaximumSolutionsParameter.Value.Value = value; }
    }

    // Value of the "Calculated Models" result; the result is created (initialized to 0) on first access.
    private IntValue CalculatedModelsResults {
      get {
        if (!Results.ContainsKey("Calculated Models")) Results.Add(new Result("Calculated Models", "The number of calculated linear models ", new IntValue(0)));
        return (IntValue)Results["Calculated Models"].Value;
      }
    }
    /// <summary>Number of linear models fitted so far in the current run.</summary>
    public int CalculatedModels {
      get { return CalculatedModelsResults.Value; }
      set { CalculatedModelsResults.Value = value; }
    }

    // Value of the "Total Models" result; the result is created (initialized to 0) on first access.
    private IntValue TotalModelsResult {
      get {
        if (!Results.ContainsKey("Total Models")) Results.Add(new Result("Total Models", "The total number of linear models to calculate", new IntValue(0)));
        return (IntValue)Results["Total Models"].Value;
      }
    }
    /// <summary>Total number of linear models the current run will fit.</summary>
    public int TotalModels {
      get { return TotalModelsResult.Value; }
      set { TotalModelsResult.Value = value; }
    }

    // Value of the "Calculated Inputs" result; the result is created (initialized to 0) on first access.
    private IntValue CalculatedInputResults {
      get {
        if (!Results.ContainsKey("Calculated Inputs")) Results.Add(new Result("Calculated Inputs", "The maximum of already calculated input combinations.", new IntValue(0)));
        return (IntValue)Results["Calculated Inputs"].Value;
      }
    }
    /// <summary>Largest combination size that has been completely evaluated so far.</summary>
    public int CalculatedInputs {
      get { return CalculatedInputResults.Value; }
      set { CalculatedInputResults.Value = value; }
    }

    [StorableConstructor]
    private LinearRegressionCombinations(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventHandlers();
    }

    private LinearRegressionCombinations(LinearRegressionCombinations original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandlers();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new LinearRegressionCombinations(this, cloner);
    }

    public LinearRegressionCombinations()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>("Maximum Inputs", "The maximum number of input variables used in the linear models.", new IntValue(1)));
      Parameters.Add(new FixedValueParameter<IntValue>("Maximum Solutions stored", "The maximum number of solutions that are stored per number of inputs.", new IntValue(1000)));
      Parameters.Add(new FixedValueParameter<BoolValue>("Create Solution", "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(false)));

      Problem = new RegressionProblem();
      RegisterEventHandlers();
    }

    // Moves the ProblemDataChanged subscription from the previously subscribed problem (if any) to the
    // current Problem. Safe to call repeatedly: at most one subscription exists afterwards.
    private void RegisterEventHandlers() {
      if (subscribedProblem != null)
        subscribedProblem.ProblemDataChanged -= Problem_ProblemDataChanged;
      subscribedProblem = Problem;
      if (subscribedProblem != null)
        subscribedProblem.ProblemDataChanged += Problem_ProblemDataChanged;
    }

    private void Problem_ProblemDataChanged(object sender, EventArgs e) {
      UpdateMaximumInputs();
    }

    protected override void OnProblemChanged() {
      base.OnProblemChanged();
      // The new problem must be subscribed (and the old one released) or data changes are missed.
      RegisterEventHandlers();
      UpdateMaximumInputs();
    }

    // Defaults MaximumInputs to the number of checked input variables; does nothing without problem data.
    private void UpdateMaximumInputs() {
      if (Problem == null || Problem.ProblemData == null) return;
      MaximumInputs = Problem.ProblemData.InputVariables.CheckedItems.Count();
    }

    // Sum of the binomial coefficients C(totalVariables, i) for i = 1..maximumInputs, i.e. the number of
    // input combinations (and therefore models) evaluated by Run.
    private static long CalculateCombinations(int maximumInputs, int totalVariables) {
      long combinations = 0;

      for (int i = 1; i <= maximumInputs; i++)
        combinations += Common.EnumerableExtensions.BinomialCoefficient(totalVariables, i);
      return combinations;
    }

    protected override void Run(CancellationToken cancellationToken) {
      var sourceProblemData = Problem.ProblemData;
      double[,] inputMatrix = sourceProblemData.Dataset.ToArray(sourceProblemData.AllowedInputVariables.Concat(new string[] { sourceProblemData.TargetVariable }), sourceProblemData.TrainingIndices);
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("Linear regression does not support NaN or infinity values in the input dataset.");

      var templateProblemData = (IRegressionProblemData)sourceProblemData.Clone();
      foreach (var variable in templateProblemData.InputVariables)
        templateProblemData.InputVariables.SetItemCheckedState(variable, false);

      var inputVariableNames = sourceProblemData.InputVariables.CheckedItems.Select(i => i.Value.Value).ToList();
      // Snapshot parameters once so the whole run (including the parallel workers) sees consistent values.
      // Combination sizes above the number of available inputs cannot produce any model, so clamp.
      var createSolution = CreateSolution;
      var maximumInputs = Math.Max(0, Math.Min(MaximumInputs, inputVariableNames.Count));
      var maximumSolutions = MaximumSolutions;

      // storedRuns[k - 1] holds the best runs using k inputs; entries stay null until size k is finished.
      var storedRuns = new List<IRun>[maximumInputs];

      // The combination count grows like 2^n; clamp instead of silently overflowing the int result.
      TotalModels = (int)Math.Min(int.MaxValue, CalculateCombinations(maximumInputs, inputVariableNames.Count));
      CalculatedModels = 0;
      CalculatedInputs = 0;

      for (int inputs = 1; inputs <= maximumInputs; inputs++) {
        var runs = new ConcurrentBag<IRun>();
        Parallel.ForEach(inputVariableNames.Combinations(inputs).ToList(), inputCombination => {
          runs.Add(CreateRun(templateProblemData, inputCombination, maximumInputs, createSolution));
        });

        CalculatedModels += runs.Count;
        CalculatedInputs = inputs;
        storedRuns[inputs - 1] = runs.OrderBy(r => ((DoubleValue)r.Results["RMSE test"]).Value).Take(maximumSolutions).ToList();

        if (cancellationToken.IsCancellationRequested) {
          // Publish what has been computed so far before aborting.
          AddRunsResult(storedRuns);
          cancellationToken.ThrowIfCancellationRequested();
        }
      }

      AddRunsResult(storedRuns);
    }

    // Fits one linear model on the given input combination and wraps its train/test RMSE (and, if requested,
    // the solution) in a run. Called concurrently from Parallel.ForEach in Run; it only reads templateProblemData.
    private static IRun CreateRun(IRegressionProblemData templateProblemData, IEnumerable<string> inputCombination, int maximumInputs, bool createSolution) {
      var inputs = inputCombination.ToList();
      var problemData = new RegressionProblemData(templateProblemData.Dataset, inputs, templateProblemData.TargetVariable);
      problemData.TrainingPartition.Start = templateProblemData.TrainingPartition.Start;
      problemData.TrainingPartition.End = templateProblemData.TrainingPartition.End;
      problemData.TestPartition.Start = templateProblemData.TestPartition.Start;
      problemData.TestPartition.End = templateProblemData.TestPartition.End;

      double trainRmsError, testRmsError;
      var solution = CreateLinearRegressionSolution(problemData, createSolution, out trainRmsError, out testRmsError);

      var run = new Run();
      run.Name = string.Format("Run - Inputs {0}/{1}", inputs.Count, maximumInputs);
      if (solution != null) run.Results.Add("Solution", solution);
      run.Results.Add("RMSE train", new DoubleValue(trainRmsError));
      run.Results.Add("RMSE test", new DoubleValue(testRmsError));
      run.Results.Add("Inputs", new IntValue(inputs.Count));
      run.Results.Add("Input names", new StringValue(string.Join(" ", inputs)));
      return run;
    }

    // Publishes all kept runs as the "Runs" result. Combination sizes that were not reached (e.g. because
    // the run was cancelled) still have null entries and are skipped instead of crashing SelectMany.
    private void AddRunsResult(IEnumerable<List<IRun>> storedRuns) {
      Results.Add(new Result("Runs", new RunCollection(storedRuns.Where(r => r != null).SelectMany(r => r))));
    }

    /// <summary>
    /// Fits a linear regression model (weights plus intercept) with alglib on the training rows of
    /// <paramref name="problemData"/>, using its allowed input variables.
    /// </summary>
    /// <param name="problemData">Provides the dataset, input variables, target variable and partitions.</param>
    /// <param name="buildSolution">When false, only the errors are computed and null is returned.</param>
    /// <param name="trainRmsError">RMSE on the training rows as reported by alglib.lrbuild.</param>
    /// <param name="testRmsError">RMSE of the fitted model on the test rows.</param>
    /// <returns>A symbolic regression solution representing the linear model, or null if <paramref name="buildSolution"/> is false.</returns>
    /// <exception cref="ArgumentException">Thrown when alglib reports that the model could not be built.</exception>
    public static ISymbolicRegressionSolution CreateLinearRegressionSolution(IRegressionProblemData problemData, bool buildSolution, out double trainRmsError, out double testRmsError) {
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      // Materialize once so the matrix columns and the coefficient order used below are guaranteed to agree.
      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
      var columns = allowedInputVariables.Concat(new string[] { targetVariable }).ToList();

      // The target is the last column, as expected by alglib.lrbuild.
      double[,] inputMatrix = dataset.ToArray(columns, problemData.TrainingIndices);
      double[,] testMatrix = dataset.ToArray(columns, problemData.TestIndices);

      int nRows = inputMatrix.GetLength(0);
      int nFeatures = inputMatrix.GetLength(1) - 1;

      int retVal;
      alglib.linearmodel lm;
      alglib.lrreport ar;
      alglib.lrbuild(inputMatrix, nRows, nFeatures, out retVal, out lm, out ar);
      if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");
      trainRmsError = ar.rmserror;

      // coefficients[0 .. nFeatures - 1] are the input weights; the last coefficient is for the constant.
      double[] coefficients;
      int unpackedFeatures;
      alglib.lrunpack(lm, out coefficients, out unpackedFeatures);
      testRmsError = alglib.lrrmserror(lm, testMatrix, testMatrix.GetLength(0));

      if (!buildSolution) return null;

      // Build the expression tree: ProgramRoot -> Start -> Addition(w_i * x_i ..., constant).
      ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode());
      ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();
      tree.Root.AddSubtree(startNode);
      ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode();
      startNode.AddSubtree(addition);

      for (int col = 0; col < allowedInputVariables.Count; col++) {
        // Fully qualified to avoid ambiguity with HeuristicLab.Core.Variable.
        VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
        vNode.VariableName = allowedInputVariables[col];
        vNode.Weight = coefficients[col];
        addition.AddSubtree(vNode);
      }

      ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode();
      cNode.Value = coefficients[coefficients.Length - 1];
      addition.AddSubtree(cNode);

      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter()), problemData);
      solution.Model.Name = "Linear Regression Model";
      solution.Name = "Linear Regression Solution";
      return solution;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.