Changeset 8814 for branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
- Timestamp:
- 10/16/12 15:19:40 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r8811 r8814 24 24 using HeuristicLab.Common; 25 25 using HeuristicLab.Core; 26 using HeuristicLab.Data;27 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 27 … … 112 111 } 113 112 114 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {113 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 115 114 if (solutions.Count() < 1) 116 115 return double.NaN; 117 116 Dataset dataset = solutions.First().ProblemData.Dataset; 118 117 var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() }) 119 .Where(a => a.Values.Equals(estimatedClassValue))118 .Where(a => handler(a.Solution.ProblemData, index) && a.Values.Equals(estimatedClassValue)) 120 119 .Select(a => a.Solution); 121 120 return (from sol in correctSolutions … … 123 122 } 124 123 125 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {124 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 126 125 if (solutions.Count() < 1) 127 126 return Enumerable.Repeat(double.NaN, indices.Count()); 128 127 128 List<int> indicesList = indices.ToList(); 129 129 130 Dataset dataset = solutions.First().ProblemData.Dataset; 130 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices ).ToArray());131 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray()); 131 132 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 132 double[] confidences = new double[indices .Count()];133 double[] confidences = new double[indicesList.Count]; 133 134 134 for (int i = 0; i < indices .Count(); i++) {135 for (int i = 0; i < indicesList.Count; i++) { 135 136 var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])); 136 137 confidences[i] = (from sol in correctSolutions 138 where handler(sol.Key.ProblemData, indicesList[i]) 137 139 select weights[sol.Key]).Sum(); 138 140 } … … 142 144 143 145 #region Helper 144 protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {145 return from i in indizes146 select targetValues[i];147 }148 146 protected bool PointInTraining(IClassificationProblemData problemData, int point) { 149 IntRange trainingPartition = problemData.TrainingPartition; 150 IntRange testPartition = problemData.TestPartition; 151 return (trainingPartition.Start <= point && point < trainingPartition.End) 152 && !(testPartition.Start <= point && point < testPartition.End); 147 return problemData.IsTrainingSample(point); 153 148 } 154 149 protected bool PointInTest(IClassificationProblemData problemData, int point) { 155 IntRange testPartition = problemData.TestPartition; 156 return testPartition.Start <= point && point < testPartition.End; 150 return problemData.IsTestSample(point); 157 151 } 158 152 protected bool AllPoints(IClassificationProblemData problemData, int point) {
Note: See TracChangeset
for help on using the changeset viewer.