Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationSolution.cs @ 5664

Last change on this file since 5664 was 5664, checked in by gkronber, 14 years ago

#1418 ported ROC, confusion matrix and discriminant function classification views and fixed bug in threshold calculation.

File size: 8.2 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Operators;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Optimization;
31using System;
32
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents a classification solution that uses a discriminant function and classification thresholds.
  /// </summary>
  /// <remarks>
  /// The solution forwards the model's ThresholdsChanged event through its own
  /// <see cref="ThresholdsChanged"/> event. When the model has no thresholds yet, they are
  /// calculated on demand the first time estimated class values are requested.
  /// </remarks>
  [StorableClass]
  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
  public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    [StorableConstructor]
    protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor. Event subscriptions are not copied with the object graph, so the
    /// clone has to subscribe to its own model again.
    /// </summary>
    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandlers();
    }

    /// <summary>
    /// Creates a solution from a regression model by wrapping it in a
    /// <see cref="DiscriminantFunctionClassificationModel"/> that uses the class values of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="model">The regression model that is used as discriminant function.</param>
    /// <param name="problemData">The classification problem data the solution refers to.</param>
    public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
      : this(new DiscriminantFunctionClassificationModel(model, problemData.ClassValues), problemData) {
    }

    /// <summary>
    /// Creates a solution for the given discriminant function model and problem data.
    /// </summary>
    /// <param name="model">The discriminant function classification model.</param>
    /// <param name="problemData">The classification problem data the solution refers to.</param>
    public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
      : base(model, problemData) {
      RegisterEventHandlers();
    }

    /// <summary>
    /// Creates a clone through the cloning constructor so that clones keep this concrete type
    /// and are subscribed to the events of their own model.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new DiscriminantFunctionClassificationSolution(this, cloner);
    }

    // Event subscriptions are not persisted, so they are re-established after loading.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventHandlers();
    }

    // Forwards threshold changes of the model as ThresholdsChanged of this solution.
    private void RegisterEventHandlers() {
      Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    }

    #region IDiscriminantFunctionClassificationSolution Members

    /// <summary>
    /// Gets the model of the base class typed as <see cref="IDiscriminantFunctionClassificationModel"/>.
    /// </summary>
    public new IDiscriminantFunctionClassificationModel Model {
      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    }

    /// <summary>
    /// Gets the values the model estimates for all rows of the dataset.
    /// </summary>
    public IEnumerable<double> EstimatedValues {
      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    }

    /// <summary>
    /// Gets the values the model estimates for the training rows.
    /// </summary>
    public IEnumerable<double> EstimatedTrainingValues {
      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    }

    /// <summary>
    /// Gets the values the model estimates for the test rows.
    /// </summary>
    public IEnumerable<double> EstimatedTestValues {
      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    }

    /// <summary>
    /// Gets the values the model estimates for the given rows of the dataset.
    /// </summary>
    /// <param name="rows">The dataset row indices to evaluate.</param>
    public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
    }

    /// <summary>
    /// Gets or sets the classification thresholds of the model. The setter hands a copy of the
    /// given values to the model.
    /// </summary>
    public IEnumerable<double> Thresholds {
      get { return Model.Thresholds; }
      set { Model.Thresholds = new List<double>(value); }
    }

    /// <summary>
    /// Raised whenever the model reports that its thresholds changed.
    /// </summary>
    public event EventHandler ThresholdsChanged;

    private void Model_ThresholdsChanged(object sender, EventArgs e) {
      OnThresholdsChanged(e);
    }

    /// <summary>
    /// Raises <see cref="ThresholdsChanged"/> with this solution as sender.
    /// </summary>
    protected virtual void OnThresholdsChanged(EventArgs e) {
      // copy to a local so that a concurrent unsubscribe cannot null the delegate between check and call
      var listener = ThresholdsChanged;
      if (listener != null) listener(this, e);
    }
    #endregion

    /// <summary>
    /// Gets the estimated class values for the given rows. Thresholds are calculated first
    /// when the model has none.
    /// </summary>
    /// <param name="rows">The dataset row indices to classify.</param>
    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
      var currentThresholds = Model.Thresholds;
      if (currentThresholds == null || !currentThresholds.Any()) RecalculateClassIntermediates();
      return base.GetEstimatedClassValues(rows);
    }

    /// <summary>
    /// Calculates one threshold between each pair of neighbouring class values and stores the
    /// result in <see cref="Thresholds"/>.
    /// </summary>
    /// <remarks>
    /// For every boundary the range of estimated values is scanned in fixed steps. Each candidate
    /// threshold is scored by summing the classification penalties over the training rows. The
    /// final threshold is the penalty-weighted mean of the lowest and highest candidate of the
    /// first series of equally good best scores.
    /// </remarks>
    private void RecalculateClassIntermediates() {
      const int slices = 100;
      List<double> estimatedValues = EstimatedValues.ToList();
      double maxEstimatedValue = estimatedValues.Max();
      double minEstimatedValue = estimatedValues.Min();

      // (estimated value, target class value) for every training row
      List<KeyValuePair<double, double>> estimatedTargetValues =
         (from row in ProblemData.TrainingIndizes
          select new KeyValuePair<double, double>(
            estimatedValues[row],
            ProblemData.Dataset[ProblemData.TargetVariable, row])).ToList();

      // de-duplicate before indexing so that the list and the number of thresholds always agree
      List<double> originalClasses = ProblemData.ClassValues.Distinct().OrderBy(x => x).ToList();
      int nClasses = originalClasses.Count;
      double[] thresholds = new double[nClasses + 1];
      thresholds[0] = double.NegativeInfinity;
      thresholds[thresholds.Length - 1] = double.PositiveInfinity;

      double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;

      for (int i = 1; i < thresholds.Length - 1; i++) {
        double lowerClass = originalClasses[i - 1];
        double upperClass = originalClasses[i];

        // the penalties only depend on the class pair, so they are looked up once per boundary
        double truePositivePenalty = ProblemData.GetClassificationPenalty(lowerClass, lowerClass);
        double falseNegativePenalty = ProblemData.GetClassificationPenalty(upperClass, lowerClass);
        double falsePositivePenalty = ProblemData.GetClassificationPenalty(lowerClass, upperClass);
        //true negative, consider only upper class
        double trueNegativePenalty = ProblemData.GetClassificationPenalty(upperClass, upperClass);

        double lowerThreshold = thresholds[i - 1];
        double searchStart = Math.Max(lowerThreshold, minEstimatedValue);
        double actualThreshold = searchStart;
        double lowestBestThreshold = double.NaN;
        double highestBestThreshold = double.NaN;
        double bestClassificationScore = double.PositiveInfinity;
        bool seriesOfEqualClassificationScores = false;

        while (actualThreshold < maxEstimatedValue) {
          double classificationScore = 0.0;

          foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
            bool insideInterval = estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold;
            if (estimatedTarget.Value.IsAlmost(lowerClass)) {
              //all positives
              classificationScore += insideInterval ? truePositivePenalty : falseNegativePenalty;
            } else {
              //all negatives
              classificationScore += insideInterval ? falsePositivePenalty : trueNegativePenalty;
            }
          }

          //new best classification score found
          if (classificationScore < bestClassificationScore) {
            bestClassificationScore = classificationScore;
            lowestBestThreshold = actualThreshold;
            highestBestThreshold = actualThreshold;
            seriesOfEqualClassificationScores = true;
          }
          //equal classification scores => if seriesOfEqualClassificationScores == true update highest threshold
          //(tolerance comparison; a difference below double.Epsilon would only be true for exactly equal sums)
          else if (classificationScore.IsAlmost(bestClassificationScore) && seriesOfEqualClassificationScores)
            highestBestThreshold = actualThreshold;
          //worse classification score found => reset seriesOfEqualClassificationScores
          else seriesOfEqualClassificationScores = false;

          // stop when the increment is too small to change the threshold, otherwise the loop would never terminate
          double nextThreshold = actualThreshold + thresholdIncrement;
          if (!(nextThreshold > actualThreshold)) break;
          actualThreshold = nextThreshold;
        }

        if (double.IsNaN(lowestBestThreshold)) {
          // no candidate produced a usable score (e.g. all estimated values are equal);
          // use the start of the search range instead of writing the NaN placeholder into the thresholds
          thresholds[i] = searchStart;
        } else {
          //scale lowest thresholds and highest found optimal threshold according to the misclassification matrix
          thresholds[i] = (lowestBestThreshold * falsePositivePenalty + highestBestThreshold * falseNegativePenalty) / (falseNegativePenalty + falsePositivePenalty);
        }
      }
      // the Thresholds setter copies the values before handing them to the model
      Thresholds = thresholds;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.