Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearRegressionModel.cs @ 18242

Last change on this file since 18242 was 17226, checked in by mkommend, 5 years ago

#2521: Merged trunk changes into problem refactoring branch.

File size: 5.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HEAL.Attic;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a linear regression model
  /// </summary>
  /// <remarks>
  /// Predictions are computed from a design matrix whose columns are the one-hot-encoded factor
  /// variables followed by the double-valued input variables (the order of <see cref="ParameterNames"/>).
  /// The last entry of <see cref="W"/> is the constant term.
  /// </remarks>
  [StorableType("B65FB0CA-7333-41FE-8156-FF141C54F5AF")]
  [Item("Linear Regression Model", "Represents a linear regression model.")]
  public sealed class LinearRegressionModel : RegressionModel, IConfidenceRegressionModel {
    public static new Image StaticItemImage {
      get { return HeuristicLab.Common.Resources.VSImageLibrary.Function; }
    }

    /// <summary>
    /// Covariance matrix (presumably of the coefficient estimates) used by
    /// <see cref="GetEstimatedVariances"/> in the quadratic form d^T * C * d, where d = [inputs..., 1].
    /// </summary>
    [Storable]
    public double[,] C {
      get; private set;
    }

    /// <summary>
    /// Coefficients in the order of <see cref="ParameterNames"/>; the last entry is the constant term.
    /// </summary>
    [Storable]
    public double[] W {
      get; private set;
    }

    /// <summary>
    /// Noise standard deviation; NoiseSigma^2 is added to every value returned by
    /// <see cref="GetEstimatedVariances"/>.
    /// </summary>
    [Storable]
    public double NoiseSigma {
      get; private set;
    }

    /// <summary>
    /// Names of all double-valued and factor variables the model reads from a dataset.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return doubleVariables.Union(factorVariables.Select(f => f.Key)); }
    }

    // Names of the double-valued input variables (columns follow the factor columns in the design matrix).
    [Storable]
    private string[] doubleVariables;
    // Factor variables: variable name -> the factor values that are one-hot-encoded as separate columns.
    [Storable]
    private List<KeyValuePair<string, IEnumerable<string>>> factorVariables;

    /// <summary>
    /// Names of the model parameters in the same order as <see cref="W"/>: one entry per factor level
    /// (formatted as "variable=value"), then the double-valued input variables, then "&lt;const&gt;"
    /// for the constant term.
    /// </summary>
    public IEnumerable<string> ParameterNames {
      get {
        return factorVariables.SelectMany(kvp => kvp.Value.Select(factorVal => $"{kvp.Key}={factorVal}"))
          .Concat(doubleVariables)
          .Concat(new[] { "<const>" });
      }
    }

    [StorableConstructor]
    private LinearRegressionModel(StorableConstructorFlag _) : base(_) {
    }

    /// <summary>
    /// Cloning constructor. All mutable state is copied so that the clone and the original
    /// do not share the publicly exposed coefficient and covariance arrays.
    /// </summary>
    private LinearRegressionModel(LinearRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      // W and C are exposed through public getters; copy them (like the variable lists below)
      // instead of sharing the same array instances between original and clone.
      this.W = (double[])original.W.Clone();
      this.C = (double[,])original.C.Clone();
      this.NoiseSigma = original.NoiseSigma;

      this.doubleVariables = (string[])original.doubleVariables.Clone();
      this.factorVariables = original.factorVariables
        .Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value)))
        .ToList();
    }

    /// <summary>
    /// Creates a new linear regression model. All arguments are copied defensively.
    /// </summary>
    /// <param name="w">Coefficients in the order of <see cref="ParameterNames"/>; the last entry is the constant term.</param>
    /// <param name="covariance">Matrix used for the variance estimate d^T * C * d.</param>
    /// <param name="noiseSigma">Noise standard deviation; its square is added to each estimated variance.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="doubleInputVariables">Names of the double-valued input variables.</param>
    /// <param name="factorVariables">Factor variables with the factor values that are one-hot-encoded.</param>
    /// <exception cref="ArgumentNullException">If any of the array or enumerable arguments is null.</exception>
    public LinearRegressionModel(double[] w, double[,] covariance, double noiseSigma, string targetVariable, IEnumerable<string> doubleInputVariables, IEnumerable<KeyValuePair<string, IEnumerable<string>>> factorVariables)
      : base(targetVariable) {
      if (w == null) throw new ArgumentNullException(nameof(w));
      if (covariance == null) throw new ArgumentNullException(nameof(covariance));
      if (doubleInputVariables == null) throw new ArgumentNullException(nameof(doubleInputVariables));
      if (factorVariables == null) throw new ArgumentNullException(nameof(factorVariables));

      this.name = ItemName;
      this.description = ItemDescription;
      this.W = (double[])w.Clone();
      this.C = (double[,])covariance.Clone();
      this.NoiseSigma = noiseSigma;
      this.doubleVariables = doubleInputVariables.ToArray();
      // copy the factor values so later changes to the caller's collections do not affect the model
      this.factorVariables = factorVariables
        .Select(kvp => new KeyValuePair<string, IEnumerable<string>>(kvp.Key, new List<string>(kvp.Value)))
        .ToList();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LinearRegressionModel(this, cloner);
    }

    /// <summary>
    /// Returns W[0..k-1] · x + W[k] for each requested row, where x is the row of the design matrix
    /// (factor columns followed by double columns) and k is the number of design-matrix columns.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = GetInputMatrix(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);

      for (int row = 0; row < n; row++) {
        double p = 0.0;
        for (int column = 0; column < columns; column++) {
          p += W[column] * inputData[row, column];
        }
        p += W[columns]; // constant term (added last to keep the original summation order)
        yield return p;
      }
    }

    /// <summary>
    /// Returns d^T * C * d + NoiseSigma^2 for each requested row, where d is the design-matrix row
    /// extended with a trailing 1 for the constant term.
    /// </summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = GetInputMatrix(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);

      double[] d = new double[C.GetLength(0)];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          d[column] = inputData[row, column];
        }
        d[columns] = 1; // constant term

        double variance = 0.0;
        for (int i = 0; i < d.Length; i++) {
          for (int j = 0; j < d.Length; j++) {
            variance += d[i] * C[i, j] * d[j];
          }
        }
        yield return variance + NoiseSigma * NoiseSigma;
      }
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new ConfidenceRegressionSolution(this, new RegressionProblemData(problemData));
    }

    /// <summary>
    /// Builds the design matrix for the given rows: factor columns first, then the double-valued
    /// input columns (the order of <see cref="ParameterNames"/> without the constant term).
    /// </summary>
    private double[,] GetInputMatrix(IDataset dataset, IEnumerable<int> rows) {
      // Materialize the rows once so both extractions below see exactly the same row sequence,
      // even when the caller passes a lazily evaluated or non-repeatable enumerable.
      int[] rowIndices = rows.ToArray();
      double[,] doubleData = dataset.ToArray(doubleVariables, rowIndices);
      double[,] factorData = dataset.ToArray(factorVariables, rowIndices);
      return factorData.HorzCat(doubleData);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.