Changeset 8814 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 10/16/12 15:19:40 (12 years ago)
- Location:
- branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs
r8534 r8814 57 57 } 58 58 59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue ) {59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 60 60 Dataset dataset = solutions.First().ProblemData.Dataset; 61 IList<double> values = solutions. Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();61 IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList(); 62 62 if (values.Count <= 0) 63 63 return double.NaN; … … 66 66 } 67 67 68 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {68 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 69 69 Dataset dataset = solutions.First().ProblemData.Dataset; 70 double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray(); 70 List<int> indicesList = indices.ToList(); 71 var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray()); 71 72 double[] confidences = new double[indices.Count()]; 72 73 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 73 74 74 for (int i = 0; i < indices.Count(); i++) { 75 double avg = values.Select(x => x[i]).Average(); 76 confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]); 75 for (int i = 0; i < indicesList.Count; i++) { 76 var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]); 77 if (values.Count() <= 0) { 78 confidences[i] = double.NaN; 79 } else { 80 double avg = values.Average(); 81 confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]); 82 } 77 83 } 78 84 … … 84 90 if (estimatedClassValue.Equals(classValues[i])) { 85 91 //special case: avgerage is higher than value of highest class 86 if (i == classValues.Length - 1 && avg > estimatedClassValue) {92 if (i == classValues.Length - 1 && avg >= estimatedClassValue) { 87 93 return 1; 88 94 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r8811 r8814 24 24 using HeuristicLab.Common; 25 25 using HeuristicLab.Core; 26 using HeuristicLab.Data;27 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 27 … … 112 111 } 113 112 114 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {113 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 115 114 if (solutions.Count() < 1) 116 115 return double.NaN; 117 116 Dataset dataset = solutions.First().ProblemData.Dataset; 118 117 var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() }) 119 .Where(a => a.Values.Equals(estimatedClassValue))118 .Where(a => handler(a.Solution.ProblemData, index) && a.Values.Equals(estimatedClassValue)) 120 119 .Select(a => a.Solution); 121 120 return (from sol in correctSolutions … … 123 122 } 124 123 125 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {124 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 126 125 if (solutions.Count() < 1) 127 126 return Enumerable.Repeat(double.NaN, indices.Count()); 128 127 128 List<int> indicesList = indices.ToList(); 129 129 130 Dataset dataset = solutions.First().ProblemData.Dataset; 130 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices ).ToArray());131 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray()); 131 132 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 132 double[] confidences = new double[indices .Count()];133 double[] confidences = new double[indicesList.Count]; 133 134 134 for (int i = 0; i < indices .Count(); i++) {135 for (int i = 0; i < indicesList.Count; i++) { 135 136 var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])); 136 137 confidences[i] = (from sol in correctSolutions 138 where handler(sol.Key.ProblemData, indicesList[i]) 137 139 select weights[sol.Key]).Sum(); 138 140 } … … 142 144 143 145 #region Helper 144 protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {145 return from i in indizes146 select targetValues[i];147 }148 146 protected bool PointInTraining(IClassificationProblemData problemData, int point) { 149 IntRange trainingPartition = problemData.TrainingPartition; 150 IntRange testPartition = problemData.TestPartition; 151 return (trainingPartition.Start <= point && point < trainingPartition.End) 152 && !(testPartition.Start <= point && point < testPartition.End); 147 return problemData.IsTrainingSample(point); 153 148 } 154 149 protected bool PointInTest(IClassificationProblemData problemData, int point) { 155 IntRange testPartition = problemData.TestPartition; 156 return testPartition.Start <= point && point < testPartition.End; 150 return problemData.IsTestSample(point); 157 151 } 158 152 protected bool AllPoints(IClassificationProblemData problemData, int point) { -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r8811 r8814 82 82 } 83 83 84 public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {84 public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 85 85 if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 86 86 return double.NaN; … … 88 88 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 89 89 90 return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue );90 return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue, handler); 91 91 } 92 92 93 protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue ) {94 return base.GetConfidence(solutions, index, estimatedClassValue );93 protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 94 return base.GetConfidence(solutions, index, estimatedClassValue, handler); 95 95 } 96 96 97 public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {97 public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 98 98 if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 99 99 return Enumerable.Repeat(double.NaN, indices.Count()); … … 101 101 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 102 102 103 return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue );103 return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue, handler); 104 104 } 105 105 106 public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {107 return base.GetConfidence(solutions, indices, estimatedClassValue );106 public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 107 return base.GetConfidence(solutions, indices, estimatedClassValue, handler); 108 108 } 109 109 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs
r8297 r8814 52 52 } 53 53 54 public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {54 public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 55 55 if (solutions.Count() < 1) 56 56 return double.NaN; 57 Dataset dataset = solutions.First().ProblemData.Dataset; 58 int correctEstimated = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First()) 57 var votingSolutions = solutions.Where(s => handler(s.ProblemData, index)); 58 if (votingSolutions.Count() < 1) 59 return double.NaN; 60 Dataset dataset = votingSolutions.First().ProblemData.Dataset; 61 int correctEstimated = votingSolutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First()) 59 62 .Where(x => x.Equals(estimatedClassValue)) 60 63 .Count(); 61 return ((double)correctEstimated / (double) solutions.Count() - 0.5) * 2;64 return ((double)correctEstimated / (double)votingSolutions.Count() - 0.5) * 2; 62 65 } 63 66 64 public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {67 public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 65 68 if (solutions.Count() < 1) 66 69 return Enumerable.Repeat(double.NaN, indices.Count()); 70 71 List<int> indicesList = indices.ToList(); 67 72 Dataset dataset = solutions.First().ProblemData.Dataset; 68 var estimationsPerSolution = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, indices).ToArray()).ToArray();73 var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray()); 69 74 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 70 75 double correctEstimated; 71 76 double[] confidences = new double[indices.Count()]; 72 77 73 for (int i = 0; i < indices.Count(); i++) { 74 correctEstimated = estimationsPerSolution.Where(x => DoubleExtensions.IsAlmost(x[i], estimatedClassValueArr[i])).Count(); 75 confidences[i] = (correctEstimated / (double)solutions.Count() - 0.5) * 2; 78 for (int i = 0; i < indicesList.Count; i++) { 79 var votingSolutions = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])); 80 correctEstimated = votingSolutions.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])).Count(); 81 confidences[i] = (correctEstimated / (double)votingSolutions.Count() - 0.5) * 2; 76 82 } 77 83 -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r8534 r8814 57 57 } 58 58 59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) { 60 59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 61 60 Dataset dataset = solutions.First().ProblemData.Dataset; 62 IList<double> values = solutions. Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();61 IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList(); 63 62 if (values.Count <= 0) 64 63 return double.NaN; … … 67 66 } 68 67 69 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {68 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 70 69 71 70 Dataset dataset = solutions.First().ProblemData.Dataset; 72 double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray(); 71 List<int> indicesList = indices.ToList(); 72 var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray()); 73 73 double[] confidences = new double[indices.Count()]; 74 74 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 75 75 76 for (int i = 0; i < indices.Count(); i++) { 77 double avg = values.Select(x => x[i]).Average(); 78 confidences[i] = GetMedianConfidence(avg, estimatedClassValueArr[i]); 76 for (int i = 0; i < indicesList.Count; i++) { 77 var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]).ToList(); 78 if (values.Count <= 0) { 79 confidences[i] = double.NaN; 80 } else { 81 double median = GetMedian(values); 82 confidences[i] = GetMedianConfidence(median, estimatedClassValueArr[i]); 83 } 79 84 } 80 85 … … 85 90 for (int i = 0; i < classValues.Length; i++) { 86 91 if (estimatedClassValue.Equals(classValues[i])) { 87 //special case: avgerageis higher than value of highest class88 if (i == classValues.Length - 1 && median > estimatedClassValue) {92 //special case: median is higher than value of highest class 93 if (i == classValues.Length - 1 && median >= estimatedClassValue) { 89 94 return 1; 90 95 } 91 //special case: averageis lower than value of lowest class96 //special case: median is lower than value of lowest class 92 97 if (i == 0 && median < estimatedClassValue) { 93 98 return 1; 94 99 } 95 //special case: averageis not between threshold of estimated class value100 //special case: median is not between threshold of estimated class value 96 101 if ((i < classValues.Length - 1 && median >= threshold[i + 1]) || median <= threshold[i]) { 97 102 return 0; 98 103 } 99 104 100 double thresholdToClassDistance, thresholdTo AverageValueDistance;105 double thresholdToClassDistance, thresholdToMedianValueDistance; 101 106 if (median >= classValues[i]) { 102 107 thresholdToClassDistance = threshold[i + 1] - classValues[i]; 103 thresholdTo AverageValueDistance = threshold[i + 1] - median;108 thresholdToMedianValueDistance = threshold[i + 1] - median; 104 109 } else { 105 110 thresholdToClassDistance = classValues[i] - threshold[i]; 106 thresholdTo AverageValueDistance = median - threshold[i];111 thresholdToMedianValueDistance = median - threshold[i]; 107 112 } 108 return (1 / thresholdToClassDistance) * thresholdTo AverageValueDistance;113 return (1 / thresholdToClassDistance) * thresholdToMedianValueDistance; 109 114 } 110 115 }
Note: See TracChangeset
for help on using the changeset viewer.