Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/SymbolicClassificationSolution.cs @ 4366

Last change on this file since 4366 was 4366, checked in by mkommend, 14 years ago

added draft version of classification (ticket #939)

File size: 8.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
30
namespace HeuristicLab.Problems.DataAnalysis.Classification {
  /// <summary>
  /// Represents a solution for a symbolic classification problem which can be visualized in the GUI.
  /// </summary>
  /// <remarks>
  /// The wrapped <see cref="SymbolicRegressionModel"/> produces a real-valued estimate for every row of the dataset.
  /// The estimates are mapped to class values through a list of thresholds. An estimate belongs to the class with
  /// index k (in <see cref="ClassificationProblemData.SortedClassValues"/>) if it lies in (thresholds[k], thresholds[k + 1]].
  /// Thresholds computed by <see cref="RecalculateClassIntermediates"/> are ascending and start/end with -/+ infinity.
  /// </remarks>
  [Item("SymbolicClassificationSolution", "Represents a solution for a symbolic classification problem which can be visualized in the GUI.")]
  [StorableClass]
  public sealed class SymbolicClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    /// <summary>
    /// Number of equally spaced candidate thresholds evaluated between two neighbouring class values.
    /// </summary>
    private const int ThresholdSlices = 1000;

    private SymbolicClassificationSolution() : base() { }

    /// <summary>
    /// Creates a classification solution that thresholds the output of a symbolic regression model.
    /// </summary>
    /// <param name="problemData">The classification problem the solution refers to.</param>
    /// <param name="model">The model producing the real-valued estimates.</param>
    /// <param name="lowerEstimationLimit">Lower limit the estimated values are clipped to.</param>
    /// <param name="upperEstimationLimit">Upper limit the estimated values are clipped to (also used for NaN estimates).</param>
    public SymbolicClassificationSolution(ClassificationProblemData problemData, SymbolicRegressionModel model, double lowerEstimationLimit, double upperEstimationLimit)
      : base(problemData, lowerEstimationLimit, upperEstimationLimit) {
      this.Model = model;
    }

    public override Image ItemImage {
      get { return HeuristicLab.Common.Resources.VS2008ImageLibrary.Function; }
    }

    /// <summary>The classification problem data this solution refers to.</summary>
    public new ClassificationProblemData ProblemData {
      get { return (ClassificationProblemData)base.ProblemData; }
      set { base.ProblemData = value; }
    }

    /// <summary>The symbolic regression model whose output is mapped to class values via thresholds.</summary>
    public new SymbolicRegressionModel Model {
      get { return (SymbolicRegressionModel)base.Model; }
      set { base.Model = value; }
    }

    /// <summary>
    /// Evaluates the model on all rows and clips the estimates to [LowerEstimationLimit, UpperEstimationLimit].
    /// NaN estimates are replaced by UpperEstimationLimit. Then recomputes the class thresholds and raises the
    /// estimated-values-changed notification.
    /// </summary>
    protected override void RecalculateEstimatedValues() {
      estimatedValues =
          (from x in Model.GetEstimatedValues(ProblemData, 0, ProblemData.Dataset.Rows)
           let boundedX = Math.Min(UpperEstimationLimit, Math.Max(LowerEstimationLimit, x))
           select double.IsNaN(boundedX) ? UpperEstimationLimit : boundedX).ToList();
      RecalculateClassIntermediates();
      OnEstimatedValuesChanged();
    }

    /// <summary>
    /// For every pair of neighbouring class values, searches the threshold on the training estimates that
    /// maximizes the Matthews correlation coefficient of "lower class vs. all other classes". The result is
    /// stored both as the optimal and as the currently active thresholds.
    /// </summary>
    private void RecalculateClassIntermediates() {
      int trainingStart = ProblemData.TrainingSamplesStart.Value;
      int trainingEnd = ProblemData.TrainingSamplesEnd.Value;
      string targetVariable = ProblemData.TargetVariable.Value;

      // (estimated value, original target value) pairs of the training partition
      List<KeyValuePair<double, double>> estimatedTargetValues =
        (from row in Enumerable.Range(trainingStart, trainingEnd - trainingStart)
         select new KeyValuePair<double, double>(
           estimatedValues[row],
           ProblemData.Dataset[targetVariable, row])).ToList();

      // use the same class values that EstimatedClassValues maps the thresholds onto,
      // so that thresholds[k] and classValues[k] always line up
      List<double> originalClasses = ProblemData.SortedClassValues.ToList();
      int numberOfClasses = originalClasses.Count;

      double[] thresholds = new double[numberOfClasses + 1];
      thresholds[0] = double.NegativeInfinity;
      thresholds[thresholds.Length - 1] = double.PositiveInfinity;

      for (int i = 1; i < thresholds.Length - 1; i++) {
        double positiveClass = originalClasses[i - 1];
        double lowerThreshold = thresholds[i - 1];
        double thresholdIncrement = (originalClasses[i] - positiveClass) / ThresholdSlices;

        double bestThreshold = positiveClass;
        double bestQuality = double.NegativeInfinity;

        // candidates are positiveClass + slice * increment; computing each one directly (instead of
        // repeatedly adding the increment) avoids accumulating floating point error
        for (int slice = 0; slice < ThresholdSlices; slice++) {
          double candidate = positiveClass + slice * thresholdIncrement;
          double quality = EvaluateThreshold(estimatedTargetValues, positiveClass, lowerThreshold, candidate);
          if (quality > bestQuality) {
            bestQuality = quality;
            bestThreshold = candidate;
          }
        }
        thresholds[i] = bestThreshold;
      }
      this.optimalThresholds = new List<double>(thresholds);
      // separate copy so that changing the active thresholds never alters the optimal ones
      this.actualThresholds = new List<double>(thresholds);
    }

    /// <summary>
    /// Computes the Matthews correlation coefficient of a one-vs-rest split. A sample is predicted as
    /// <paramref name="positiveClass"/> if its estimated value lies in
    /// (<paramref name="lowerThreshold"/>, <paramref name="upperThreshold"/>]. This interval is consistent
    /// with the mapping used by <see cref="EstimatedClassValues"/>.
    /// </summary>
    /// <param name="estimatedTargetValues">Pairs of (estimated value, original target value).</param>
    /// <param name="positiveClass">The class value treated as positive.</param>
    /// <param name="lowerThreshold">Exclusive lower bound of the positive region.</param>
    /// <param name="upperThreshold">Inclusive upper bound of the positive region.</param>
    /// <returns>The MCC in [-1, 1], or 0 if it is undefined.</returns>
    private static double EvaluateThreshold(IEnumerable<KeyValuePair<double, double>> estimatedTargetValues,
      double positiveClass, double lowerThreshold, double upperThreshold) {
      int truePositives = 0;
      int falsePositives = 0;
      int trueNegatives = 0;
      int falseNegatives = 0;

      foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
        bool predictedPositive = estimatedTarget.Key > lowerThreshold && estimatedTarget.Key <= upperThreshold;
        bool actualPositive = estimatedTarget.Value.IsAlmost(positiveClass);
        if (actualPositive) {
          if (predictedPositive) truePositives++;
          else falseNegatives++;
        } else {
          if (predictedPositive) falsePositives++;
          else trueNegatives++;
        }
      }
      return MatthewsCorrelationCoefficient(truePositives, falsePositives, trueNegatives, falseNegatives);
    }

    /// <summary>
    /// Matthews correlation coefficient:
    /// MCC = (TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)).
    /// </summary>
    /// <returns>The MCC, or 0 if the denominator is zero (the numerator is then zero as well).</returns>
    private static double MatthewsCorrelationCoefficient(int truePositives, int falsePositives, int trueNegatives, int falseNegatives) {
      // double arithmetic: the integer products overflow already for a few hundred samples
      double dividend = (double)truePositives * trueNegatives - (double)falsePositives * falseNegatives;
      double divisor = Math.Sqrt((double)(truePositives + falsePositives) * (truePositives + falseNegatives) *
        (trueNegatives + falsePositives) * (trueNegatives + falseNegatives));
      if (divisor == 0)
        return 0.0;
      return dividend / divisor;
    }

    #region properties
    private List<double> optimalThresholds;
    private List<double> actualThresholds;

    /// <summary>
    /// The currently active class thresholds, computed lazily on first access.
    /// Setting a different sequence raises <see cref="ThresholdsChanged"/>.
    /// </summary>
    public IEnumerable<double> Thresholds {
      get {
        if (actualThresholds == null) RecalculateEstimatedValues();
        // read-only view so callers cannot change the thresholds without raising ThresholdsChanged
        return actualThresholds.AsReadOnly();
      }
      set {
        if (value == null) throw new ArgumentNullException("value");
        if (actualThresholds != null && actualThresholds.SequenceEqual(value))
          return;
        actualThresholds = new List<double>(value);
        OnThresholdsChanged();
      }
    }

    private List<double> estimatedValues;
    public override IEnumerable<double> EstimatedValues {
      get {
        if (estimatedValues == null) RecalculateEstimatedValues();
        return estimatedValues.AsEnumerable();
      }
    }

    /// <summary>
    /// The estimated values of all rows mapped to class values using the active thresholds.
    /// An estimate v is assigned to the class with index k if thresholds[k] &lt; v &lt;= thresholds[k + 1].
    /// </summary>
    public IEnumerable<double> EstimatedClassValues {
      get {
        double[] classValues = ProblemData.SortedClassValues.ToArray();
        // accessing EstimatedValues first guarantees that estimated values and thresholds are computed
        IEnumerable<double> values = EstimatedValues;
        List<double> thresholds = actualThresholds;
        // never step beyond the last class or threshold, even if the thresholds were set to a list of unexpected length
        int maxClassIndex = Math.Min(classValues.Length - 1, thresholds.Count - 2);
        foreach (double value in values) {
          int classIndex = 0;
          while (classIndex < maxClassIndex && value > thresholds[classIndex + 1])
            classIndex++;
          yield return classValues[classIndex];
        }
      }
    }

    public override IEnumerable<double> EstimatedTrainingValues {
      get {
        if (estimatedValues == null) RecalculateEstimatedValues();
        int start = ProblemData.TrainingSamplesStart.Value;
        int n = ProblemData.TrainingSamplesEnd.Value - start;
        return estimatedValues.Skip(start).Take(n).ToList();
      }
    }
    public IEnumerable<double> EstimatedTrainingClassValues {
      get {
        int start = ProblemData.TrainingSamplesStart.Value;
        int n = ProblemData.TrainingSamplesEnd.Value - start;
        return EstimatedClassValues.Skip(start).Take(n).ToList();
      }
    }

    public override IEnumerable<double> EstimatedTestValues {
      get {
        if (estimatedValues == null) RecalculateEstimatedValues();
        int start = ProblemData.TestSamplesStart.Value;
        int n = ProblemData.TestSamplesEnd.Value - start;
        return estimatedValues.Skip(start).Take(n).ToList();
      }
    }
    public IEnumerable<double> EstimatedTestClassValues {
      get {
        int start = ProblemData.TestSamplesStart.Value;
        int n = ProblemData.TestSamplesEnd.Value - start;
        return EstimatedClassValues.Skip(start).Take(n).ToList();
      }
    }
    #endregion

    /// <summary>Raised when the active thresholds are replaced through the <see cref="Thresholds"/> setter.</summary>
    public event EventHandler ThresholdsChanged;
    private void OnThresholdsChanged() {
      // invoke the local copy: the event field may be unsubscribed concurrently after the null check
      EventHandler handler = ThresholdsChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.