Changeset 5657
- Timestamp:
- 03/10/11 12:37:11 (14 years ago)
- Location:
- branches/DataAnalysis Refactoring
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs
r5649 r5657 62 62 : base(tree, interpreter) { 63 63 this.classValues = classValues.ToArray(); 64 this.thresholds = new double[0]; 64 65 } 65 66 … … 80 81 else break; 81 82 } 82 yield return classValues.ElementAt(classIndex );83 yield return classValues.ElementAt(classIndex - 1); 83 84 } 84 85 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationSolution.cs
r5649 r5657 61 61 public override IDeepCloneable Clone(Cloner cloner) { 62 62 return new SymbolicDiscriminantFunctionClassificationSolution(this, cloner); 63 } 63 } 64 64 } 65 65 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationSolution.cs
r5649 r5657 74 74 75 75 public IEnumerable<double> Thresholds { 76 get { return Model.Thresholds; } 76 get { 77 return Model.Thresholds; 78 } 79 protected set { Model.Thresholds = value; } 77 80 } 78 81 … … 88 91 } 89 92 #endregion 93 94 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 95 if (Model.Thresholds == null || Model.Thresholds.Count() == 0) RecalculateClassIntermediates(); 96 return base.GetEstimatedClassValues(rows); 97 } 98 99 private void RecalculateClassIntermediates() { 100 int slices = 100; 101 List<double> estimatedValues = EstimatedValues.ToList(); 102 List<int> classInstances = (from classValue in ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable) 103 group classValue by classValue into grouping 104 select grouping.Count()).ToList(); 105 double maxEstimatedValue = estimatedValues.Max(); 106 double minEstimatedValue = estimatedValues.Min(); 107 List<KeyValuePair<double, double>> estimatedTargetValues = 108 (from row in ProblemData.TrainingIndizes 109 select new KeyValuePair<double, double>( 110 estimatedValues[row], 111 ProblemData.Dataset[ProblemData.TargetVariable, row])).ToList(); 112 113 List<double> originalClasses = ProblemData.ClassValues.OrderBy(x => x).ToList(); 114 int nClasses = originalClasses.Distinct().Count(); 115 double[] thresholds = new double[nClasses + 1]; 116 thresholds[0] = double.NegativeInfinity; 117 thresholds[thresholds.Length - 1] = double.PositiveInfinity; 118 119 for (int i = 1; i < thresholds.Length - 1; i++) { 120 double lowerThreshold = thresholds[i - 1]; 121 double actualThreshold = minEstimatedValue; 122 double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices; 123 124 double lowestBestThreshold = double.NaN; 125 double highestBestThreshold = double.NaN; 126 double bestClassificationScore = double.PositiveInfinity; 127 bool seriesOfEqualClassificationScores = false; 128 129 while (actualThreshold < maxEstimatedValue) { 130 double classificationScore = 0.0; 131 132 foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) { 133 //all positives 134 if (estimatedTarget.Value.IsAlmost(originalClasses[i - 1])) { 135 if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold) 136 //true positive 137 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i - 1]); 138 else 139 //false negative 140 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]); 141 } 142 //all negatives 143 else { 144 if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold) 145 //false positive 146 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]); 147 else 148 //true negative, consider only upper class 149 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i]); 150 } 151 } 152 153 //new best classification score found 154 if (classificationScore < bestClassificationScore) { 155 bestClassificationScore = classificationScore; 156 lowestBestThreshold = actualThreshold; 157 highestBestThreshold = actualThreshold; 158 seriesOfEqualClassificationScores = true; 159 } 160 //equal classification scores => if seriesOfEqualClassifcationScores == true update highest threshold 161 else if (Math.Abs(classificationScore - bestClassificationScore) < double.Epsilon && seriesOfEqualClassificationScores) 162 highestBestThreshold = actualThreshold; 163 //worse classificatoin score found reset seriesOfEqualClassifcationScores 164 else seriesOfEqualClassificationScores = false; 165 166 actualThreshold += thresholdIncrement; 167 } 168 //scale lowest thresholds and highest found optimal threshold according to the misclassification matrix 169 double falseNegativePenalty = ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]); 170 double falsePositivePenalty = ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]); 171 thresholds[i] = (lowestBestThreshold * falsePositivePenalty + highestBestThreshold * falseNegativePenalty) / (falseNegativePenalty + falsePositivePenalty); 172 } 173 Thresholds = new List<double>(thresholds); 174 } 90 175 } 91 176 } -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs
r5649 r5657 24 24 namespace HeuristicLab.Problems.DataAnalysis { 25 25 public interface IDiscriminantFunctionClassificationModel : IClassificationModel { 26 IEnumerable<double> Thresholds { get; }26 IEnumerable<double> Thresholds { get; set; } 27 27 event EventHandler ThresholdsChanged; 28 28 IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows); -
branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/OnlineEvaluators/OnlineAccuracyEvaluator.cs
r5649 r5657 34 34 throw new InvalidOperationException("No elements"); 35 35 else 36 return correctlyClassified / n;36 return correctlyClassified / (double)n; 37 37 } 38 38 }
Note: See TracChangeset
for help on using the changeset viewer.