Changeset 7549 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
- Timestamp:
- 03/05/12 17:02:37 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r7531 r7549 46 46 protected double[] classValues; 47 47 48 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 49 List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>(); 50 List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>(); 48 /// <summary> 49 /// 50 /// </summary> 51 /// <param name="discriminantSolutions"></param> 52 /// <returns>median instead of weights, because it doesn't use any weights</returns> 53 protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 54 List<List<double>> estimatedValues = new List<List<double>>(); 55 List<List<double>> estimatedClassValues = new List<List<double>>(); 56 57 List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList(); 58 Dataset dataSet = solutionProblemData[0].Dataset; 59 IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows); 51 60 foreach (var solution in discriminantSolutions) { 52 estimated TrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList());53 estimated TrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList());61 estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 62 estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 54 63 } 55 64 56 65 List<double> median = new List<double>(); 57 58 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 59 List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(); 60 IEnumerable<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes); 61 62 for (int i = 0; i < estimatedTrainingClassValEnumerators.First().Count; i++) { 63 var points = (from solution in estimatedTrainingValEnumerators 64 select solution[i]) 65 .OrderBy(p => p) 66 .ToList(); 67 68 median.Add(GetMedian(points)); 66 List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList(); 67 IList<double> curTrainingpoints = new List<double>(); 68 int removed = 0; 69 int count = targetValues.Count; 70 for (int point = 0; point < count; point++) { 71 curTrainingpoints.Clear(); 72 for (int solutionPos = 0; solutionPos < solutionProblemData.Count; solutionPos++) { 73 if (PointInTraining(solutionProblemData[solutionPos], point)) { 74 curTrainingpoints.Add(estimatedValues[solutionPos][point]); 75 } 76 } 77 if (curTrainingpoints.Count > 0) 78 median.Add(GetMedian(curTrainingpoints.OrderBy(p => p).ToList())); 79 else { 80 //remove not used points 81 targetValues.RemoveAt(point - removed); 82 removed++; 83 } 69 84 } 70 AccuracyMaximizationThresholdCalculator.CalculateThresholds( problemData, median, trainingVal, out classValues, out threshold);85 AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold); 71 86 return median; 72 87 } 73 88 74 protected override double DiscriminantAggregateEstimatedClassValues(I Enumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {89 protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) { 75 90 double classValue = classValues.First(); 76 double median = GetMedian(estimatedValues.ToList()); 91 IList<double> values = estimatedValues.Select(x => x.Value).ToList(); 92 if (values.Count <= 0) 93 return double.NaN; 94 double median = GetMedian(values); 77 95 for (int i = 0; i < classValues.Count(); i++) { 78 96 if (median > threshold[i]) … … 87 105 int count = estimatedValues.Count; 88 106 if (count % 2 == 0) 89 return 0.5 * (estimatedValues[count / 2 ] + estimatedValues[count / 2 + 1]);107 return 0.5 * (estimatedValues[count / 2 - 1] + estimatedValues[count / 2]); 90 108 else 91 return estimatedValues[ (count + 1)/ 2];109 return estimatedValues[count / 2]; 92 110 } 93 111 }
Note: See TracChangeset
for help on using the changeset viewer.