Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/Splitting/UnivariateOnlineLR.cs @ 18183

Last change on this file since 18183 was 17209, checked in by gkronber, 5 years ago

#2994: merged r17132:17198 from trunk to branch

File size: 4.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Helper class for incremental split calculation.
  /// Used while moving a potential splitter along the ordered training instances.
  /// Maintains the univariate least-squares fit target ~ alpha + beta * attribute of the current
  /// instance set together with its sum of squared residuals, and updates both in constant time
  /// when a single instance is added or removed.
  /// </summary>
  internal class UnivariateOnlineLR {
    #region state
    // Running means of the target and attribute values.
    private NeumaierSum targetMean;
    private NeumaierSum attributeMean;
    // Sum of squared deviations of the attribute values from their mean (population variance * size).
    private NeumaierSum attributeVarSum;
    // Sum of (attribute - attributeMean) * (target - targetMean).
    private NeumaierSum comoment;
    // Sum of squared residuals of the current fit.
    private NeumaierSum ssr;
    private int size;
    #endregion

    /// <summary>
    /// Sum of squared residuals of the least-squares fit over the current instance set.
    /// </summary>
    public double Ssr {
      get { return ssr.Get(); }
    }

    /// <summary>
    /// Number of instances currently in the set.
    /// </summary>
    public int Size {
      get { return size; }
    }

    /// <summary>
    /// Least-squares slope. If the attribute scatter is not positive (mathematically: all attribute
    /// values are equal, or the set is empty) the slope is undefined; 0 is used so that the fit
    /// degenerates to the target mean. Without this fallback the division yields NaN/Infinity,
    /// which would contaminate the accumulated <see cref="Ssr"/> permanently.
    /// </summary>
    private double Beta {
      get {
        var scatter = attributeVarSum.Get();
        return scatter > 0 ? comoment.Get() / scatter : 0.0;
      }
    }

    /// <summary>
    /// Least-squares intercept; the fitted line passes through the point of means.
    /// </summary>
    private double Alpha {
      get { return targetMean.Get() - Beta * attributeMean.Get(); }
    }

    /// <summary>
    /// Fits the regression to the given instances.
    /// </summary>
    /// <param name="attributeValues">Attribute value of each instance; may be empty.</param>
    /// <param name="targetValues">Target value of each instance, aligned with <paramref name="attributeValues"/>.</param>
    /// <exception cref="ArgumentNullException">Either collection is null.</exception>
    /// <exception cref="ArgumentException">The collections differ in length.</exception>
    public UnivariateOnlineLR(ICollection<double> attributeValues, ICollection<double> targetValues) {
      if (attributeValues == null) throw new ArgumentNullException(nameof(attributeValues));
      if (targetValues == null) throw new ArgumentNullException(nameof(targetValues));
      if (attributeValues.Count != targetValues.Count) throw new ArgumentException("Targets and Attributes need to have the same length");

      if (attributeValues.Count == 0) {
        // Average() throws on an empty sequence; start from the well-defined empty state instead.
        Clear();
        return;
      }

      size = attributeValues.Count;
      var yMean = targetValues.Average();
      var xMean = attributeValues.Average();
      targetMean = new NeumaierSum(yMean);
      attributeMean = new NeumaierSum(xMean);
      attributeVarSum = new NeumaierSum(attributeValues.VariancePop() * size);
      comoment = new NeumaierSum(attributeValues.Zip(targetValues, (x, y) => (x - xMean) * (y - yMean)).Sum());

      // Use the same model as Alpha/Beta (including the constant-model fallback) for the initial SSR.
      var beta = Beta;
      var alpha = Alpha;
      ssr = new NeumaierSum(attributeValues.Zip(targetValues, (x, y) => y - alpha - beta * x).Sum(r => r * r));
    }

    /// <summary>
    /// Adds one instance and updates the fit and <see cref="Ssr"/> in constant time.
    /// </summary>
    public void Add(double attributeValue, double targetValue) {
      // Residual of the new instance under the fit WITHOUT it (a-priori residual).
      var residualBefore = Residual(attributeValue, targetValue);

      size++;
      // Welford-style update of the means, the attribute scatter and the co-moment.
      var dx = attributeValue - attributeMean.Get();
      var dy = targetValue - targetMean.Get();
      attributeMean.Add(dx / size);
      targetMean.Add(dy / size);
      var dx2 = attributeValue - attributeMean.Get();
      var dy2 = targetValue - targetMean.Get();
      attributeVarSum.Add(dx * dx2);
      comoment.Add(dx * dy2);

      // Recursive least squares: SSR_new = SSR_old + (a-priori residual) * (a-posteriori residual).
      ssr.Add(residualBefore * Residual(attributeValue, targetValue));
    }

    /// <summary>
    /// Removes one instance and updates the fit and <see cref="Ssr"/> in constant time.
    /// The instance must be part of the current set; this is not checked.
    /// </summary>
    /// <exception cref="InvalidOperationException">The set is empty.</exception>
    public void Remove(double attributeValue, double targetValue) {
      if (size <= 0) throw new InvalidOperationException("Cannot remove an instance from an empty set.");
      if (size == 1) {
        // Removing the last instance: the general downdate would divide by (size - 1) == 0.
        Clear();
        return;
      }

      // Residual of the instance under the fit WITH it (a-posteriori residual).
      var residualBefore = Residual(attributeValue, targetValue);

      var dx2 = attributeValue - attributeMean.Get();
      var dy2 = targetValue - targetMean.Get();
      // mean_(n-1) = (n * mean_n - value) / (n - 1)
      attributeMean.Mul(size / (size - 1.0));
      targetMean.Mul(size / (size - 1.0));
      attributeMean.Add(-attributeValue / (size - 1.0));
      targetMean.Add(-targetValue / (size - 1.0));
      // Inverse of the Welford-style update performed in Add.
      var dx = attributeValue - attributeMean.Get();
      attributeVarSum.Add(-dx * dx2);
      comoment.Add(-dx * dy2);
      size--;

      // Inverse of the update in Add: SSR_old = SSR_new - (a-posteriori residual) * (a-priori residual).
      ssr.Add(-residualBefore * Residual(attributeValue, targetValue));
    }

    /// <summary>
    /// Resets all statistics to those of an empty instance set.
    /// </summary>
    private void Clear() {
      size = 0;
      targetMean = new NeumaierSum(0.0);
      attributeMean = new NeumaierSum(0.0);
      attributeVarSum = new NeumaierSum(0.0);
      comoment = new NeumaierSum(0.0);
      ssr = new NeumaierSum(0.0);
    }

    /// <summary>
    /// Residual targetValue - (alpha + beta * attributeValue) of an instance under the current fit.
    /// </summary>
    private double Residual(double attributeValue, double targetValue) {
      return targetValue - Alpha - Beta * attributeValue;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.