Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/GAM.cs @ 15433

Last change to this file was revision 15433, checked in by gkronber, 6 years ago

#2789 added the possibility to include interaction terms in GAM

File size: 12.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Concurrent;
24using System.Collections.Generic;
25using System.Linq;
26using System.Threading;
27using System.Threading.Tasks;
28using HeuristicLab.Analysis;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
33using HeuristicLab.Optimization;
34using HeuristicLab.Parameters;
35using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
36using HeuristicLab.Problems.DataAnalysis;
37using HeuristicLab.Problems.DataAnalysis.Symbolic;
38using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
39
40namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
41  // UNFINISHED
42  [Item("Generalized Additive Modelling", "GAM")]
43  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 102)]
44  [StorableClass]
45  public sealed class GAM : FixedDataAnalysisAlgorithm<IRegressionProblem> {
46    [StorableConstructor]
47    private GAM(bool deserializing) : base(deserializing) { }
48    [StorableHook(HookType.AfterDeserialization)]
49    private void AfterDeserialization() {     
50    }
51
52    private GAM(GAM original, Cloner cloner)
53      : base(original, cloner) {
54    }
55    public override IDeepCloneable Clone(Cloner cloner) {
56      return new GAM(this, cloner);
57    }
58
59    public GAM()
60      : base() {
61      Problem = new RegressionProblem();
62      Parameters.Add(new ValueParameter<DoubleValue>("Lambda", "Regularization for smoothing splines", new DoubleValue(1.0)));
63      Parameters.Add(new ValueParameter<IntValue>("Max iterations", "", new IntValue(100)));
64      Parameters.Add(new ValueParameter<IntValue>("Max interactions", "", new IntValue(1)));
65    }   
66
67    protected override void Run(CancellationToken cancellationToken) {
68
69      double lambda = ((IValueParameter<DoubleValue>)Parameters["Lambda"]).Value.Value;
70      int maxIters = ((IValueParameter<IntValue>)Parameters["Max iterations"]).Value.Value;
71      int maxInteractions = ((IValueParameter<IntValue>)Parameters["Max interactions"]).Value.Value;
72      if (maxInteractions < 1 || maxInteractions > 5) throw new ArgumentException("Max interactions is outside the valid range [1 .. 5]");
73
74      // calculates a GAM model using a linear representation + independent non-linear functions of each variable
75      // using backfitting algorithm (see The Elements of Statistical Learning page 298)
76
77      var problemData = Problem.ProblemData;
78      var y = problemData.TargetVariableTrainingValues.ToArray();
79      var avgY = y.Average();
80      var inputVars = Problem.ProblemData.AllowedInputVariables.ToArray();
81      var nTerms = inputVars.Length; // LR
82      for(int i=1;i<=maxInteractions;i++) {
83        nTerms += inputVars.Combinations(i).Count();
84      }
85      IRegressionModel[] f = new IRegressionModel[nTerms];
86      for(int i=0;i<f.Length;i++) {
87        f[i] = new ConstantModel(0.0, problemData.TargetVariable);
88      }
89
90      var rmseTable = new DataTable("RMSE");
91      var rmseRow = new DataRow("RMSE (train)");
92      var rmseRowTest = new DataRow("RMSE (test)");
93      rmseTable.Rows.Add(rmseRow);
94      rmseTable.Rows.Add(rmseRowTest);
95
96      Results.Add(new Result("RMSE", rmseTable));
97      rmseRow.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TrainingIndices).StandardDeviation()); // -1 index to use all predictors
98      rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation());
99
100      // for analytics
101      double[] rss = new double[f.Length];
102      string[] terms = new string[f.Length];
103      Results.Add(new Result("RSS Values", typeof(DoubleMatrix)));
104
105      // until convergence
106      int iters = 0;
107      var t = new double[y.Length];
108      while (iters++ < maxIters) {
109        int j = 0;
110        foreach (var inputVar in inputVars) {
111          var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices);
112          rss[j] = res.Variance();
113          terms[j] = inputVar;
114          f[j] = RegressLR(problemData, inputVar, res);
115          j++;
116        }
117
118        for(int interaction = 1; interaction <= maxInteractions;interaction++) {
119          var selectedVars = HeuristicLab.Common.EnumerableExtensions.Combinations(inputVars, interaction);
120
121          foreach (var element in selectedVars) {
122            var res = CalculateResiduals(problemData, f, j, avgY, problemData.TrainingIndices);
123            rss[j] = res.Variance();
124            terms[j] = string.Format("f({0})", string.Join(",", element));
125            f[j] = RegressSpline(problemData, element.ToArray(), res, lambda);
126            j++;
127          }
128        }
129
130        rmseRow.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TrainingIndices).StandardDeviation()); // -1 index to use all predictors
131        rmseRowTest.Values.Add(CalculateResiduals(problemData, f, -1, avgY, problemData.TestIndices).StandardDeviation());
132
133        // calculate table with residual contributions of each term
134        var rssTable = new DoubleMatrix(rss.Length, 1, new string[] { "RSS" }, terms);
135        for (int i = 0; i < rss.Length; i++) rssTable[i, 0] = rss[i];
136        Results["RSS Values"].Value = rssTable;
137
138        if (cancellationToken.IsCancellationRequested) break;
139      }
140
141      var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
142      model.AverageModelEstimates = false;
143      Results.Add(new Result("Ensemble solution", model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone())));
144
145    }
146
147    private double[] CalculateResiduals(IRegressionProblemData problemData, IRegressionModel[] f, int j, double avgY, IEnumerable<int> rows) {
148      var y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
149      double[] t = y.Select(yi => yi - avgY).ToArray();
150      // collect other predictions
151      for (int k = 0; k < f.Length; k++) {
152        if (k != j) {
153          var pred = f[k].GetEstimatedValues(problemData.Dataset, rows).ToArray();
154          // determine target for this smoother
155          for (int i = 0; i < t.Length; i++) {
156            t[i] -= pred[i];
157          }
158        }
159      }
160      return t;
161    }
162
163    private IRegressionModel RegressLR(IRegressionProblemData problemData, string inputVar, double[] target) {
164      // Umständlich!
165      var ds = ((Dataset)problemData.Dataset).ToModifiable();
166      ds.ReplaceVariable(problemData.TargetVariable, target.Concat(Enumerable.Repeat(0.0, ds.Rows - target.Length)).ToList<double>());
167      var pd = new RegressionProblemData(ds, new string[] { inputVar }, problemData.TargetVariable);
168      pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
169      pd.TrainingPartition.End = problemData.TrainingPartition.End;
170      pd.TestPartition.Start = problemData.TestPartition.Start;
171      pd.TestPartition.End = problemData.TestPartition.End;
172      double rmsError, cvRmsError;
173      return LinearRegression.CreateLinearRegressionSolution(pd, out rmsError, out cvRmsError).Model;
174    }
175
176    private IRegressionModel RegressSpline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
177      if (problemData.Dataset.VariableHasType<double>(inputVar)) {
178        // Umständlich!
179        return Splines.CalculatePenalizedRegressionSpline(
180          problemData.Dataset.GetDoubleValues(inputVar, problemData.TrainingIndices).ToArray(),
181          (double[])target.Clone(), lambda,
182          problemData.TargetVariable, new string[] { inputVar }
183          );
184      } else return new ConstantModel(target.Average(), problemData.TargetVariable);
185    }
186    private IRegressionModel RegressSpline(IRegressionProblemData problemData, string[] inputVars, double[] target, double lambda) {
187      if (inputVars.All(problemData.Dataset.VariableHasType<double>)) {
188        var product = problemData.Dataset.GetDoubleValues(inputVars.First(), problemData.TrainingIndices).ToArray();
189        for(int i = 1;i<inputVars.Length;i++) {
190          product = product.Zip(problemData.Dataset.GetDoubleValues(inputVars[i], problemData.TrainingIndices), (pi, vi) => pi * vi).ToArray();
191        }
192        // Umständlich!
193        return Splines.CalculatePenalizedRegressionSpline(
194          product,
195          (double[])target.Clone(), lambda,
196          problemData.TargetVariable, inputVars
197          );
198      } else return new ConstantModel(target.Average(), problemData.TargetVariable);
199    }
200
201    private IRegressionModel RegressRF(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
202      if (problemData.Dataset.VariableHasType<double>(inputVar)) {
203        // Umständlich!
204        var ds = ((Dataset)problemData.Dataset).ToModifiable();
205        ds.ReplaceVariable(problemData.TargetVariable, target.Concat(Enumerable.Repeat(0.0, ds.Rows - target.Length)).ToList<double>());
206        var pd = new RegressionProblemData(ds, new string[] { inputVar }, problemData.TargetVariable);
207        pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
208        pd.TrainingPartition.End = problemData.TrainingPartition.End;
209        pd.TestPartition.Start = problemData.TestPartition.Start;
210        pd.TestPartition.End = problemData.TestPartition.End;
211        double rmsError, oobRmsError;
212        double avgRelError, oobAvgRelError;
213        return RandomForestRegression.CreateRandomForestRegressionModel(pd, 100, 0.5, 0.5, 1234, out rmsError, out avgRelError, out oobRmsError, out oobAvgRelError);
214      } else return new ConstantModel(target.Average(), problemData.TargetVariable);
215    }
216  }
217
218
219  // UNFINISHED
220  public class RBFModel : NamedItem, IRegressionModel {
221    private alglib.rbfmodel model;
222
223    public string TargetVariable { get; set; }
224
225    public IEnumerable<string> VariablesUsedForPrediction { get; private set; }
226    private ITransformation<double>[] scaling;
227
228    public event EventHandler TargetVariableChanged;
229
230    public RBFModel(RBFModel orig, Cloner cloner) : base(orig, cloner) {
231      this.TargetVariable = orig.TargetVariable;
232      this.VariablesUsedForPrediction = orig.VariablesUsedForPrediction.ToArray();
233      this.model = (alglib.rbfmodel)orig.model.make_copy();
234      this.scaling = orig.scaling.Select(s => cloner.Clone(s)).ToArray();
235    }
236    public RBFModel(alglib.rbfmodel model, string targetVar, string[] inputs, IEnumerable<ITransformation<double>> scaling) : base("RBFModel", "RBFModel") {
237      this.model = model;
238      this.TargetVariable = targetVar;
239      this.VariablesUsedForPrediction = inputs;
240      this.scaling = scaling.ToArray();
241    }
242
243    public override IDeepCloneable Clone(Cloner cloner) {
244      return new RBFModel(this, cloner);
245    }
246
247    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
248      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
249    }
250
251    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
252      double[] x = new double[VariablesUsedForPrediction.Count()];
253      double[] y;
254      foreach (var r in rows) {
255        int c = 0;
256        foreach (var v in VariablesUsedForPrediction) {
257          x[c] = scaling[c].Apply(dataset.GetDoubleValue(v, r).ToEnumerable()).First(); // OUCH!
258          c++;
259        }
260        alglib.rbfcalc(model, x, out y);
261        yield return y[0];
262      }
263    }
264  }
265}
Note: See TracBrowser for help on using the repository browser.