Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/LeafModels/PreconstructedLinearModel.cs @ 17209

Last change on this file since 17209 was 17209, checked in by gkronber, 5 years ago

#2994: merged r17132:17198 from trunk to branch

File size: 6.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Problems.DataAnalysis;
27using HEAL.Attic;
28
namespace HeuristicLab.Algorithms.DataAnalysis {
  // multidimensional extension of http://www2.stat.duke.edu/~tjl13/s101/slides/unit6lec3H.pdf
  /// <summary>
  /// Linear regression model with fixed parameters:
  /// prediction(row) = Intercept + sum over Coefficients of (coefficient * value of that variable in row).
  /// </summary>
  [StorableType("15F2295C-28C1-48C3-8DCB-9470823C6734")]
  internal sealed class PreconstructedLinearModel : RegressionModel {
    /// <summary>Coefficient per input variable name; only these variables are read when predicting.</summary>
    [Storable]
    public Dictionary<string, double> Coefficients { get; private set; }
    /// <summary>Constant term added to every prediction.</summary>
    [Storable]
    public double Intercept { get; private set; }

    /// <summary>The names of all variables that have a coefficient (the keys of <see cref="Coefficients"/>).</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return Coefficients.Keys; }
    }

    #region HLConstructors
    [StorableConstructor]
    private PreconstructedLinearModel(StorableConstructorFlag _) : base(_) { }
    // Cloning constructor: the coefficient dictionary is copied, not shared with the original.
    private PreconstructedLinearModel(PreconstructedLinearModel original, Cloner cloner) : base(original, cloner) {
      if (original.Coefficients != null) Coefficients = original.Coefficients.ToDictionary(x => x.Key, x => x.Value);
      Intercept = original.Intercept;
    }
    /// <summary>Creates a model from the given coefficients (copied) and intercept.</summary>
    public PreconstructedLinearModel(Dictionary<string, double> coefficients, double intercept, string targetvariable) : base(targetvariable) {
      Coefficients = new Dictionary<string, double>(coefficients);
      Intercept = intercept;
    }
    /// <summary>Creates a constant model that always predicts <paramref name="intercept"/>.</summary>
    public PreconstructedLinearModel(double intercept, string targetvariable) : base(targetvariable) {
      Coefficients = new Dictionary<string, double>();
      Intercept = intercept;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new PreconstructedLinearModel(this, cloner);
    }
    #endregion

    /// <summary>
    /// Fits an ordinary least-squares linear model (with intercept) of the target variable on all allowed
    /// input variables, using the training rows of <paramref name="pd"/>.
    /// </summary>
    /// <param name="pd">Problem data providing dataset, input variables, target variable and training rows.</param>
    /// <param name="rmse">Root mean squared error of the fitted model on the training rows.</param>
    /// <returns>The fitted model.</returns>
    /// <exception cref="ArgumentException">Thrown when the fallback alglib linear regression reports an error.</exception>
    public static PreconstructedLinearModel CreateLinearModel(IRegressionProblemData pd, out double rmse) {
      return AlternativeCalculation(pd, out rmse);
    }

    // Fits the model using only alglib's lrbuild on the training rows (currently unused; kept as the reference approach).
    private static PreconstructedLinearModel ClassicCalculation(IRegressionProblemData pd) {
      var variables = pd.AllowedInputVariables.ToList();
      var coefficients = SolveWithLrBuild(pd, variables, pd.TrainingIndices.ToArray());
      return CreateModel(pd, variables, coefficients);
    }

    // Fits the model by solving the normal equations; falls back to lrbuild when that is not possible.
    private static PreconstructedLinearModel AlternativeCalculation(IRegressionProblemData pd, out double rmse) {
      var variables = pd.AllowedInputVariables.ToList();
      // Materialize once: the training rows are enumerated several times below.
      var rows = pd.TrainingIndices.ToArray();

      double[] coefficients;
      // If the Cholesky-based solution fails, fall back to classic linear regression.
      if (!TrySolveNormalEquations(pd, variables, rows, out coefficients))
        coefficients = SolveWithLrBuild(pd, variables, rows);

      var model = CreateModel(pd, variables, coefficients);
      var sse = rows
        .Select(r => pd.Dataset.GetDoubleValue(pd.TargetVariable, r) - model.GetEstimatedValue(pd.Dataset, r))
        .Sum(e => e * e);
      rmse = Math.Sqrt(sse / rows.Length);
      return model;
    }

    // Solves (X^T X) w = X^T y via Cholesky, where X has one column per variable plus a trailing column of ones (intercept).
    // Returns false (caller should fall back) if there are no rows, X^T X is not positive definite, or the solver reports failure.
    // On success, coefficients holds one entry per variable (same order) followed by the intercept.
    private static bool TrySolveNormalEquations(IRegressionProblemData pd, IList<string> variables, int[] rows, out double[] coefficients) {
      coefficients = null;
      var n = variables.Count;
      var m = rows.Length;
      if (m == 0) return false;

      // X^T: (n + 1) x m, last row all ones for the intercept
      var xT = new double[n + 1, m];
      for (var i = 0; i < n; i++) {
        var values = pd.Dataset.GetDoubleValues(variables[i], rows).ToArray();
        for (var j = 0; j < m; j++) xT[i, j] = values[j];
      }
      for (var j = 0; j < m; j++) xT[n, j] = 1;

      // y: m x 1
      var y = new double[m, 1];
      var target = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToArray();
      for (var j = 0; j < m; j++) y[j, 0] = target[j];

      var xTy = new double[n + 1, 1];
      var xTx = new double[n + 1, n + 1];
      // xTy = X^T * y   (beta = 0: result overwrites xTy)
      alglib.rmatrixgemm(n + 1, 1, m, 1, xT, 0, 0, 0, y, 0, 0, 0, 0, ref xTy, 0, 0);
      // xTx = X^T * (X^T)^T = X^T X   (optypeb = 1 transposes B; beta = 0: result overwrites xTx)
      alglib.rmatrixgemm(n + 1, n + 1, m, 1, xT, 0, 0, 0, xT, 0, 0, 1, 0, ref xTx, 0, 0);

      // spdmatrixcholesky signals a non-positive-definite matrix via its return value; the partial
      // factorization left in xTx in that case must not be handed to the solver.
      if (!alglib.spdmatrixcholesky(ref xTx, n + 1, true)) return false;

      var rhs = new double[n + 1];
      for (var i = 0; i < n + 1; i++) rhs[i] = xTy[i, 0];
      int info;
      alglib.densesolverreport report;
      alglib.spdmatrixcholeskysolve(xTx, n + 1, true, rhs, out info, out report, out coefficients);
      return info == 1;
    }

    // Fits the model with alglib's lrbuild on the given rows.
    // Returns one coefficient per variable (same order) followed by the intercept.
    private static double[] SolveWithLrBuild(IRegressionProblemData pd, IList<string> variables, IEnumerable<int> rows) {
      // columns: input variables in the given order, then the target variable
      var inputMatrix = pd.Dataset.ToArray(variables.Concat(new[] { pd.TargetVariable }), rows);
      alglib.linearmodel lm;
      alglib.lrreport ar;
      int retVal;
      alglib.lrbuild(inputMatrix, inputMatrix.GetLength(0), variables.Count, out retVal, out lm, out ar);
      if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");

      double[] coefficients;
      int nVars;
      alglib.lrunpack(lm, out coefficients, out nVars);
      return coefficients;
    }

    // Builds the model from a coefficient vector laid out as [w_0 .. w_{n-1}, intercept].
    private static PreconstructedLinearModel CreateModel(IRegressionProblemData pd, IList<string> variables, double[] coefficients) {
      var n = variables.Count;
      var coeffs = Enumerable.Range(0, n).ToDictionary(i => variables[i], i => coefficients[i]);
      return new PreconstructedLinearModel(coeffs, coefficients[n], pd.TargetVariable);
    }

    /// <summary>Predicts the target for each of the given rows of <paramref name="dataset"/>.</summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      return rows.Select(row => GetEstimatedValue(dataset, row));
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }

    #region helpers
    // Intercept plus the weighted sum of this row's values for all variables with a coefficient (0 if there are none).
    private double GetEstimatedValue(IDataset dataset, int row) {
      return Intercept + Coefficients.Sum(s => s.Value * dataset.GetDoubleValue(s.Key, row));
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.