- Timestamp:
- 07/17/12 15:30:04 (12 years ago)
- Location:
- branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r8177 r8297 28 28 using HeuristicLab.Data; 29 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 30 using HeuristicLab.Problems.DataAnalysis.Interfaces .Classification;30 using HeuristicLab.Problems.DataAnalysis.Interfaces; 31 31 32 32 namespace HeuristicLab.Problems.DataAnalysis { -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs
r8101 r8297 66 66 return double.NaN; 67 67 double avg = values.Average(); 68 return GetAverageConfidence(avg, estimatedClassValue); 69 } 70 71 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) { 72 if (!classValues.Count().Equals(2)) 73 return Enumerable.Repeat(double.NaN, indices.Count()); 74 75 Dataset dataset = solutions.First().ProblemData.Dataset; 76 double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray(); 77 double[] confidences = new double[indices.Count()]; 78 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 79 80 for (int i = 0; i < indices.Count(); i++) { 81 double avg = values.Select(x => x[i]).Average(); 82 confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]); 83 } 84 85 return confidences; 86 } 87 88 protected double GetAverageConfidence(double avg, double estimatedClassValue) { 68 89 if (estimatedClassValue.Equals(classValues[0])) { 69 90 if (avg < estimatedClassValue) -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r7562 r8297 26 26 using HeuristicLab.Data; 27 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 using HeuristicLab.Problems.DataAnalysis.Interfaces .Classification;28 using HeuristicLab.Problems.DataAnalysis.Interfaces; 29 29 30 30 namespace HeuristicLab.Problems.DataAnalysis { … … 124 124 } 125 125 126 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) { 127 if (solutions.Count() < 1) 128 return Enumerable.Repeat(double.NaN, indices.Count()); 129 130 Dataset dataset = solutions.First().ProblemData.Dataset; 131 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices).ToArray()); 132 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 133 double[] confidences = new double[indices.Count()]; 134 135 for (int i = 0; i < indices.Count(); i++) { 136 var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])); 137 confidences[i] = (from sol in correctSolutions 138 select weights[sol.Key]).Sum(); 139 } 140 141 return confidences; 142 } 143 126 144 #region Helper 127 145 protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) { -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r8177 r8297 24 24 using HeuristicLab.Common; 25 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 26 using HeuristicLab.Problems.DataAnalysis.Interfaces .Classification;26 using HeuristicLab.Problems.DataAnalysis.Interfaces; 27 27 28 28 namespace HeuristicLab.Problems.DataAnalysis { … … 95 95 return base.GetConfidence(solutions, index, estimatedClassValue); 96 96 } 97 98 public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) { 99 if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 100 return Enumerable.Repeat(double.NaN, indices.Count()); 101 102 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 103 104 return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue); 105 } 106 107 public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) { 108 return base.GetConfidence(solutions, indices, estimatedClassValue); 109 } 97 110 } 98 111 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs
r7866 r8297 61 61 return ((double)correctEstimated / (double)solutions.Count() - 0.5) * 2; 62 62 } 63 64 public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) { 65 if (solutions.Count() < 1) 66 return Enumerable.Repeat(double.NaN, indices.Count()); 67 Dataset dataset = solutions.First().ProblemData.Dataset; 68 var estimationsPerSolution = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, indices).ToArray()).ToArray(); 69 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 70 double correctEstimated; 71 double[] confidences = new double[indices.Count()]; 72 73 for (int i = 0; i < indices.Count(); i++) { 74 correctEstimated = estimationsPerSolution.Where(x => DoubleExtensions.IsAlmost(x[i], estimatedClassValueArr[i])).Count(); 75 confidences[i] = (correctEstimated / (double)solutions.Count() - 0.5) * 2; 76 } 77 78 return confidences; 79 } 63 80 } 64 81 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r8101 r8297 66 66 return double.NaN; 67 67 double median = GetMedian(values); 68 return GetMedianConfidence(median, estimatedClassValue); 69 } 70 71 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) { 72 if (!classValues.Count().Equals(2)) 73 return Enumerable.Repeat(double.NaN, indices.Count()); 74 75 Dataset dataset = solutions.First().ProblemData.Dataset; 76 double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray(); 77 double[] confidences = new double[indices.Count()]; 78 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 79 80 for (int i = 0; i < indices.Count(); i++) { 81 double avg = values.Select(x => x[i]).Average(); 82 confidences[i] = GetMedianConfidence(avg, estimatedClassValueArr[i]); 83 } 84 85 return confidences; 86 } 87 88 protected double GetMedianConfidence(double median, double estimatedClassValue) { 68 89 if (estimatedClassValue.Equals(classValues[0])) { 69 90 if (median < estimatedClassValue) -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs
r7562 r8297 23 23 using HeuristicLab.Core; 24 24 25 namespace HeuristicLab.Problems.DataAnalysis.Interfaces .Classification{25 namespace HeuristicLab.Problems.DataAnalysis.Interfaces { 26 26 public delegate bool CheckPoint(IClassificationProblemData problemData, int point); 27 27 … … 30 30 IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler); 31 31 double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue); 32 IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue); 32 33 33 34 CheckPoint GetTestClassDelegate();
Note: See TracChangeset
for help on using the changeset viewer.