Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/SymbolicClassificationSolution.cs @ 5370

Last change on this file since 5370 was 5341, checked in by mkommend, 14 years ago

Corrected minor flaw regarding equal classification scores during the calculation of the optimal thresholds (ticket #1383).

File size: 8.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
29
namespace HeuristicLab.Problems.DataAnalysis.Classification {
  /// <summary>
  /// Represents a solution for a symbolic classification problem which can be visualized in the GUI.
  /// </summary>
  /// <remarks>
  /// The continuous estimates of the model are mapped to class values by comparing them against an
  /// ascending list of thresholds. The thresholds are either calculated from the training partition
  /// (see <see cref="RecalculateClassIntermediates"/>) or supplied through <see cref="Thresholds"/>.
  /// </remarks>
  [Item("SymbolicClassificationSolution", "Represents a solution for a symbolic classification problem which can be visualized in the GUI.")]
  [StorableClass]
  public class SymbolicClassificationSolution : SymbolicRegressionSolution, IClassificationSolution {
    /// <summary>
    /// Gets or sets the problem data of the base class, typed as <see cref="ClassificationProblemData"/>.
    /// </summary>
    public new ClassificationProblemData ProblemData {
      get { return (ClassificationProblemData)base.ProblemData; }
      set { base.ProblemData = value; }
    }

    #region properties
    // thresholds calculated by RecalculateClassIntermediates()
    private List<double> optimalThresholds;
    // thresholds currently used for classification (calculated ones or the ones set via Thresholds)
    private List<double> actualThresholds;

    /// <summary>
    /// Gets or sets the thresholds which are used to map estimated values to class values.
    /// </summary>
    /// <remarks>
    /// The getter triggers a recalculation of the estimated values (and therefore of the thresholds) if
    /// no thresholds are available yet. The setter raises <see cref="ThresholdsChanged"/> only if the
    /// new sequence differs from the current one.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown if the assigned value is <c>null</c>.</exception>
    public IEnumerable<double> Thresholds {
      get {
        if (actualThresholds == null) RecalculateEstimatedValues();
        return actualThresholds;
      }
      set {
        if (value == null) throw new ArgumentNullException("value");
        // materialize exactly once; value could be a lazily evaluated sequence
        List<double> newThresholds = new List<double>(value);
        if (actualThresholds != null && actualThresholds.SequenceEqual(newThresholds))
          return;
        actualThresholds = newThresholds;
        OnThresholdsChanged();
      }
    }

    /// <summary>
    /// Gets the estimated class values for the row indices 0 to <c>ProblemData.Dataset.Rows - 1</c>.
    /// </summary>
    public IEnumerable<double> EstimatedClassValues {
      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    }

    /// <summary>
    /// Gets the estimated class values for the rows in <c>ProblemData.TrainingIndizes</c>.
    /// </summary>
    public IEnumerable<double> EstimatedTrainingClassValues {
      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
    }

    /// <summary>
    /// Gets the estimated class values for the rows in <c>ProblemData.TestIndizes</c>.
    /// </summary>
    public IEnumerable<double> EstimatedTestClassValues {
      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicClassificationSolution(bool deserializing) : base(deserializing) { }
    // NOTE(review): the threshold fields are not copied here; a clone obtains thresholds lazily via
    // RecalculateEstimatedValues(). Thresholds assigned through the Thresholds setter are therefore
    // presumably not carried over to clones - confirm whether this is intended.
    protected SymbolicClassificationSolution(SymbolicClassificationSolution original, Cloner cloner) : base(original, cloner) { }
    /// <summary>
    /// Creates a new solution; all arguments are forwarded unchanged to the base class constructor.
    /// </summary>
    public SymbolicClassificationSolution(ClassificationProblemData problemData, SymbolicRegressionModel model, double lowerEstimationLimit, double upperEstimationLimit)
      : base(problemData, model, lowerEstimationLimit, upperEstimationLimit) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicClassificationSolution(this, cloner);
    }

    /// <summary>
    /// Recalculates the estimated values and the classification thresholds.
    /// </summary>
    /// <remarks>
    /// Every model output is clamped to [LowerEstimationLimit, UpperEstimationLimit]; outputs that are
    /// NaN after clamping are replaced by UpperEstimationLimit. Afterwards the thresholds are
    /// recalculated and OnEstimatedValuesChanged() is called.
    /// </remarks>
    protected override void RecalculateEstimatedValues() {
      estimatedValues =
          (from x in Model.GetEstimatedValues(ProblemData, 0, ProblemData.Dataset.Rows)
           let boundedX = Math.Min(UpperEstimationLimit, Math.Max(LowerEstimationLimit, x))
           select double.IsNaN(boundedX) ? UpperEstimationLimit : boundedX).ToList();
      RecalculateClassIntermediates();
      OnEstimatedValuesChanged();
    }

    /// <summary>
    /// Calculates the thresholds between neighboring class values on the training partition.
    /// </summary>
    /// <remarks>
    /// For each pair of neighboring class values, candidate thresholds are evaluated in steps of
    /// 1/slices of the distance between the two class values. Each candidate is scored by summing the
    /// entries of <c>ProblemData.MisclassificationMatrix</c> over all training rows; the lowest score wins.
    /// If a contiguous series of candidates shares the best score, the resulting threshold is the
    /// penalty-weighted average of the first and the last candidate of that series.
    /// The result replaces both the optimal and the currently used thresholds.
    /// </remarks>
    private void RecalculateClassIntermediates() {
      const int slices = 100;

      // key = estimated value, value = target value of each training row
      List<KeyValuePair<double, double>> estimatedTargetValues =
         (from row in ProblemData.TrainingIndizes
          select new KeyValuePair<double, double>(
            estimatedValues[row],
            ProblemData.Dataset[ProblemData.TargetVariable.Value, row])).ToList();

      List<double> originalClasses = ProblemData.SortedClassValues.ToList();
      var misclassificationMatrix = ProblemData.MisclassificationMatrix;

      // outermost thresholds are open so that every estimated value falls into some class
      double[] thresholds = new double[ProblemData.NumberOfClasses + 1];
      thresholds[0] = double.NegativeInfinity;
      thresholds[thresholds.Length - 1] = double.PositiveInfinity;

      for (int i = 1; i < thresholds.Length - 1; i++) {
        double lowerClassValue = originalClasses[i - 1];
        double upperClassValue = originalClasses[i];
        double lowerThreshold = thresholds[i - 1];
        double thresholdIncrement = (upperClassValue - lowerClassValue) / slices;

        // the matrix entries do not depend on the candidate threshold or the row => read them once
        double truePositiveCost = misclassificationMatrix[i - 1, i - 1];
        double falseNegativePenalty = misclassificationMatrix[i, i - 1];
        double falsePositivePenalty = misclassificationMatrix[i - 1, i];
        double trueNegativeCost = misclassificationMatrix[i, i];

        double lowestBestThreshold = double.NaN;
        double highestBestThreshold = double.NaN;
        double bestClassificationScore = double.PositiveInfinity;
        bool seriesOfEqualClassificationScores = false;

        // NOTE(review): the candidate is advanced by repeated addition, so rounding errors accumulate
        // and the number of evaluated candidates may deviate from slices by one.
        double actualThreshold = lowerClassValue;
        while (actualThreshold < upperClassValue) {
          double classificationScore = 0.0;

          foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
            // NOTE(review): the upper bound is exclusive here, whereas GetEstimatedClassValues assigns a
            // value that is equal to a threshold to the lower class.
            bool assignedToLowerClass = estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold;

            if (estimatedTarget.Value.IsAlmost(lowerClassValue)) {
              //all positives: true positive if assigned to the lower class, false negative otherwise
              classificationScore += assignedToLowerClass ? truePositiveCost : falseNegativePenalty;
            } else {
              //all negatives: false positive if assigned to the lower class,
              //true negative otherwise (consider only upper class)
              classificationScore += assignedToLowerClass ? falsePositivePenalty : trueNegativeCost;
            }
          }

          if (classificationScore < bestClassificationScore) {
            //new best classification score found => start a new series of equal scores
            bestClassificationScore = classificationScore;
            lowestBestThreshold = actualThreshold;
            highestBestThreshold = actualThreshold;
            seriesOfEqualClassificationScores = true;
          } else if (Math.Abs(classificationScore - bestClassificationScore) < double.Epsilon && seriesOfEqualClassificationScores) {
            //equal classification score within an uninterrupted series => update highest threshold
            //(double.Epsilon is the smallest positive double, so this is effectively a test for exact equality)
            highestBestThreshold = actualThreshold;
          } else {
            //worse classification score found (or series already interrupted) => end the series
            seriesOfEqualClassificationScores = false;
          }

          actualThreshold += thresholdIncrement;
        }

        //scale lowest and highest found optimal threshold according to the misclassification matrix
        double penaltySum = falseNegativePenalty + falsePositivePenalty;
        if (penaltySum == 0.0) {
          // both penalties are zero: the weighted average would be 0 / 0 = NaN, use the midpoint instead
          thresholds[i] = (lowestBestThreshold + highestBestThreshold) / 2.0;
        } else {
          thresholds[i] = (lowestBestThreshold * falsePositivePenalty + highestBestThreshold * falseNegativePenalty) / penaltySum;
        }
      }

      this.optimalThresholds = new List<double>(thresholds);
      // separate copy: the list is handed out by the Thresholds getter and must not alias the optimum
      this.actualThresholds = new List<double>(optimalThresholds);
    }

    /// <summary>
    /// Maps the estimated values of the given rows to class values.
    /// </summary>
    /// <remarks>
    /// A row is assigned the class value at the smallest index whose upper threshold is not exceeded
    /// by the estimated value of the row. The sequence is evaluated lazily; estimated values and
    /// thresholds are recalculated on first enumeration if they are not available.
    /// </remarks>
    /// <param name="rows">The row indices for which class values should be returned.</param>
    /// <returns>One class value (taken from <c>ProblemData.SortedClassValues</c>) per given row.</returns>
    public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
      double[] classValues = ProblemData.SortedClassValues.ToArray();
      if (estimatedValues == null || actualThresholds == null)
        RecalculateEstimatedValues();
      foreach (int row in rows) {
        double value = estimatedValues[row];
        int classIndex = 0;
        // The bounds checks only matter for externally supplied thresholds that are too short or do not
        // end with positive infinity; in that case the value is assigned to the last reachable class.
        while (classIndex + 1 < classValues.Length
               && classIndex + 1 < actualThresholds.Count
               && value > actualThresholds[classIndex + 1])
          classIndex++;
        yield return classValues[classIndex];
      }
    }

    /// <summary>
    /// Raised when the thresholds are replaced via the <see cref="Thresholds"/> setter.
    /// </summary>
    public event EventHandler ThresholdsChanged;
    private void OnThresholdsChanged() {
      // invoke the local copy (not the field) so that a concurrent unsubscribe cannot cause a null dereference
      var handler = ThresholdsChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.