source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearRegressionModel.cs @ 17074

Last change on this file since 17074 was 17074, checked in by abeham, 22 months ago

#2892: merged to stable

File size: 5.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a linear regression model
  /// </summary>
  /// <remarks>
  /// The model input vector consists of the one-hot-encoded factor variables, followed by the
  /// double-valued variables, followed by a constant 1 for the intercept. The coefficients in
  /// <see cref="W"/> follow the same order as <see cref="ParameterNames"/>.
  /// </remarks>
  [StorableClass]
  [Item("Linear Regression Model", "Represents a linear regression model.")]
  public sealed class LinearRegressionModel : RegressionModel, IConfidenceRegressionModel {
    public static new Image StaticItemImage {
      get { return HeuristicLab.Common.Resources.VSImageLibrary.Function; }
    }

    /// <summary>
    /// Covariance matrix of the coefficients (passed in as <c>covariance</c>). It is used as the
    /// quadratic form in <see cref="GetEstimatedVariances"/>.
    /// </summary>
    [Storable]
    public double[,] C {
      get; private set;
    }

    /// <summary>
    /// Coefficient vector, one entry per name in <see cref="ParameterNames"/>.
    /// The last element is the constant (intercept) term.
    /// </summary>
    [Storable]
    public double[] W {
      get; private set;
    }

    /// <summary>
    /// Noise standard deviation. Its square is added to every estimated variance.
    /// </summary>
    [Storable]
    public double NoiseSigma {
      get; private set;
    }

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return doubleVariables.Union(factorVariables.Select(f => f.Key)); }
    }

    [Storable]
    private string[] doubleVariables;
    [Storable]
    private List<KeyValuePair<string, IEnumerable<string>>> factorVariables;

    /// <summary>
    /// Names of the model parameters in coefficient order: the one-hot-encoded factor variables
    /// (formatted as "name=value"), then the double-valued variables, then "&lt;const&gt;" for the intercept.
    /// </summary>
    public IEnumerable<string> ParameterNames {
      get {
        return factorVariables.SelectMany(kvp => kvp.Value.Select(factorVal => $"{kvp.Key}={factorVal}"))
          .Concat(doubleVariables)
          .Concat(new[] { "<const>" });
      }
    }

    [StorableConstructor]
    private LinearRegressionModel(bool deserializing)
      : base(deserializing) {
    }

    private LinearRegressionModel(LinearRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      // Deep-copy the coefficient arrays. W and C are exposed through public getters, so sharing the
      // instances would let a mutation on one model leak into its clones. The public constructor
      // copies its inputs for the same reason.
      this.W = (double[])original.W?.Clone();
      this.C = (double[,])original.C?.Clone();
      this.NoiseSigma = original.NoiseSigma;

      this.doubleVariables = (string[])original.doubleVariables.Clone();
      this.factorVariables = CopyFactorVariables(original.factorVariables);
    }

    /// <summary>
    /// Creates a linear regression model. All array and collection arguments are copied.
    /// </summary>
    /// <param name="w">Coefficients in the order of <see cref="ParameterNames"/>; the last entry is the intercept.</param>
    /// <param name="covariance">Covariance matrix of the coefficients.</param>
    /// <param name="noiseSigma">Noise standard deviation.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="doubleInputVariables">Names of the double-valued input variables.</param>
    /// <param name="factorVariables">Factor variable names with their one-hot-encoded values.</param>
    /// <exception cref="ArgumentNullException"><paramref name="w"/> or <paramref name="covariance"/> is null.</exception>
    public LinearRegressionModel(double[] w, double[,] covariance, double noiseSigma, string targetVariable, IEnumerable<string> doubleInputVariables, IEnumerable<KeyValuePair<string, IEnumerable<string>>> factorVariables)
      : base(targetVariable) {
      if (w == null) {
        throw new ArgumentNullException(nameof(w));
      }
      if (covariance == null) {
        throw new ArgumentNullException(nameof(covariance));
      }
      this.name = ItemName;
      this.description = ItemDescription;
      this.W = (double[])w.Clone();
      this.C = (double[,])covariance.Clone();
      this.NoiseSigma = noiseSigma;
      this.doubleVariables = doubleInputVariables.ToArray();
      this.factorVariables = CopyFactorVariables(factorVariables);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // intentionally empty
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LinearRegressionModel(this, cloner);
    }

    /// <summary>
    /// Computes the point prediction <c>W[0..k-1] · x + W[k]</c> for each requested row,
    /// where x is the row of the input matrix and k is its column count.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = GetInputMatrix(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);

      for (int row = 0; row < n; row++) {
        double p = 0.0;
        for (int column = 0; column < columns; column++) {
          p += W[column] * inputData[row, column];
        }
        p += W[columns]; // intercept
        yield return p;
      }
    }

    /// <summary>
    /// Computes the predictive variance <c>dᵀ C d + NoiseSigma²</c> for each requested row, where
    /// d is the row of the input matrix with a trailing 1 appended for the intercept.
    /// </summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = GetInputMatrix(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);

      // Reused buffer for the augmented input vector of the current row.
      double[] d = new double[C.GetLength(0)];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          d[column] = inputData[row, column];
        }
        d[columns] = 1;

        double var = 0.0;
        for (int i = 0; i < d.Length; i++) {
          for (int j = 0; j < d.Length; j++) {
            var += d[i] * C[i, j] * d[j];
          }
        }
        yield return var + NoiseSigma * NoiseSigma;
      }
    }

    /// <summary>
    /// Wraps this model in a <see cref="ConfidenceRegressionSolution"/> backed by a new
    /// <see cref="RegressionProblemData"/> constructed from <paramref name="problemData"/>.
    /// </summary>
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new ConfidenceRegressionSolution(this, new RegressionProblemData(problemData));
    }

    /// <summary>
    /// Builds the input matrix for the given rows: the factor-variable columns horizontally
    /// concatenated with the double-variable columns, in the coefficient order of <see cref="W"/>.
    /// </summary>
    private double[,] GetInputMatrix(IDataset dataset, IEnumerable<int> rows) {
      // Materialize once: rows is consumed by two separate ToArray calls and may be lazy or single-pass.
      IList<int> rowList = rows as IList<int> ?? rows.ToList();
      double[,] doubleData = dataset.ToArray(doubleVariables, rowList);
      double[,] factorData = dataset.ToArray(factorVariables, rowList);
      return factorData.HorzCat(doubleData);
    }

    /// <summary>
    /// Copies the factor-variable list, including each value collection, so the model owns its own data.
    /// </summary>
    private static List<KeyValuePair<string, IEnumerable<string>>> CopyFactorVariables(IEnumerable<KeyValuePair<string, IEnumerable<string>>> source) {
      return source
        .Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value)))
        .ToList();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.