Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DataAnalysis/3.4/OnlineCalculators/OnlineWeightedClassificationMeanSquaredErrorCalculator.cs @ 16892

Last change on this file since 16892 was 16892, checked in by gkronber, 5 years ago

#2925 merged r16661:16890 from trunk to branch

File size: 5.3 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25
26namespace HeuristicLab.Problems.DataAnalysis {
27  public class OnlineWeightedClassificationMeanSquaredErrorCalculator : IOnlineCalculator {
28
29    private double sse;
30    private int n;
31    public double WeightedResidualsMeanSquaredError {
32      get {
33        return n > 0 ? sse / n : 0.0;
34      }
35    }
36
37    public double PositiveClassValue { get; private set; }
38    public double ClassValuesMax { get; private set; }
39    public double ClassValuesMin { get; private set; }
40    public double DefiniteResidualsWeight { get; private set; }
41    public double PositiveClassResidualsWeight { get; private set; }
42    public double NegativeClassesResidualsWeight { get; private set; }
43
44    public OnlineWeightedClassificationMeanSquaredErrorCalculator(double positiveClassValue, double classValuesMax, double classValuesMin,
45                                                                double definiteResidualsWeight, double positiveClassResidualsWeight, double negativeClassesResidualsWeight) {
46      PositiveClassValue = positiveClassValue;
47      ClassValuesMax = classValuesMax;
48      ClassValuesMin = classValuesMin;
49      DefiniteResidualsWeight = definiteResidualsWeight;
50      PositiveClassResidualsWeight = positiveClassResidualsWeight;
51      NegativeClassesResidualsWeight = negativeClassesResidualsWeight;
52      Reset();
53    }
54
55    #region IOnlineCalculator Members
56    private OnlineCalculatorError errorState;
57    public OnlineCalculatorError ErrorState {
58      get { return errorState; }
59    }
60    public double Value {
61      get { return WeightedResidualsMeanSquaredError; }
62    }
63    public void Reset() {
64      n = 0;
65      sse = 0.0;
66      errorState = OnlineCalculatorError.InsufficientElementsAdded;
67    }
68
69    public void Add(double original, double estimated) {
70      if (double.IsNaN(estimated) || double.IsInfinity(estimated) ||
71          double.IsNaN(original) || double.IsInfinity(original) || (errorState & OnlineCalculatorError.InvalidValueAdded) > 0) {
72        errorState = errorState | OnlineCalculatorError.InvalidValueAdded;
73      } else {
74        double error = estimated - original;
75        double weight;
76        //apply weight
77        if (estimated > ClassValuesMax || estimated < ClassValuesMin) {
78          weight = DefiniteResidualsWeight;
79        } else if (original.IsAlmost(PositiveClassValue)) {
80          weight = PositiveClassResidualsWeight;
81        } else {
82          weight = NegativeClassesResidualsWeight;
83        }
84        sse += error * error * weight;
85        n++;
86        errorState = errorState & (~OnlineCalculatorError.InsufficientElementsAdded);        // n >= 1
87      }
88    }
89    #endregion
90
91    public static double Calculate(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, double positiveClassValue, double classValuesMax, double classValuesMin,
92                                                                double definiteResidualsWeight, double positiveClassResidualsWeight, double negativeClassesResidualsWeight, out OnlineCalculatorError errorState) {
93      IEnumerator<double> originalEnumerator = originalValues.GetEnumerator();
94      IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator();
95      OnlineWeightedClassificationMeanSquaredErrorCalculator calculator = new OnlineWeightedClassificationMeanSquaredErrorCalculator(positiveClassValue, classValuesMax, classValuesMin, definiteResidualsWeight, positiveClassResidualsWeight, negativeClassesResidualsWeight);
96
97      // always move forward both enumerators (do not use short-circuit evaluation!)
98      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
99        double original = originalEnumerator.Current;
100        double estimated = estimatedEnumerator.Current;
101        calculator.Add(original, estimated);
102        if (calculator.ErrorState != OnlineCalculatorError.None) break;
103      }
104
105      // check if both enumerators are at the end to make sure both enumerations have the same length
106      if (calculator.ErrorState == OnlineCalculatorError.None &&
107         (estimatedEnumerator.MoveNext() || originalEnumerator.MoveNext())) {
108        throw new ArgumentException("Number of elements in originalValues and estimatedValues enumerations doesn't match.");
109      } else {
110        errorState = calculator.ErrorState;
111        return calculator.WeightedResidualsMeanSquaredError;
112      }
113    }
114  }
115}
Note: See TracBrowser for help on using the repository browser.