Changeset 8814
- Timestamp:
- 10/16/12 15:19:40 (12 years ago)
- Location:
- branches/ClassificationEnsembleVoting
- Files:
-
- 2 deleted
- 10 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs
r8811 r8814 26 26 using System.Windows.Forms; 27 27 using System.Windows.Forms.DataVisualization.Charting; 28 using HeuristicLab.Data;29 28 using HeuristicLab.MainForm; 30 29 … … 82 81 IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator; 83 82 var solutions = Content.ClassificationSolutions; 84 double[] estimatedClassValues ;83 double[] estimatedClassValues = null; 85 84 double[] classValues; 86 85 OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator(); 87 86 88 int rows; 89 double[] confidences; 87 int rows = 0; 88 double[] confidences = null; 89 90 classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray(); 90 91 91 92 if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxAllSamples)) { 92 93 rows = Content.ProblemData.Dataset.Rows; 93 94 estimatedClassValues = Content.EstimatedClassValues.ToArray(); 94 classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray(); 95 confidences = weightCalc.GetConfidence(solutions, Enumerable.Range(0, rows), estimatedClassValues).ToArray(); 96 } else { 97 IntRange range; 98 if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTrainingSamples)) { 99 range = Content.ProblemData.TrainingPartition; 100 estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray(); 101 } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTestSamples)) { 102 range = Content.ProblemData.TestPartition; 103 estimatedClassValues = Content.EstimatedTestClassValues.ToArray(); 104 } else { 105 return; 106 } 107 rows = range.End - range.Start; 108 classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable) 109 .Skip(range.Start).Take(range.End - range.Start).ToArray(); 110 confidences = new double[rows]; 111 int index; 112 for (int i = 0; i < rows; i++) { 113 index = range.Start + i; 114 confidences[i] = weightCalc.GetConfidence(GetRelevantSolutions(SamplesComboBox.SelectedItem.ToString(), solutions, index), 115 index, estimatedClassValues[i]); 116 } 95 confidences = weightCalc.GetConfidence(solutions, 96 Enumerable.Range(0, Content.ProblemData.Dataset.Rows), 97 estimatedClassValues, 98 weightCalc.GetAllClassDelegate()).ToArray(); 99 } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTrainingSamples)) { 100 rows = Content.ProblemData.TrainingIndices.Count(); 101 estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray(); 102 confidences = weightCalc.GetConfidence(solutions, 103 Content.ProblemData.TrainingIndices, 104 estimatedClassValues, 105 weightCalc.GetTrainingClassDelegate()).ToArray(); 106 } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTestSamples)) { 107 rows = Content.ProblemData.TestIndices.Count(); 108 estimatedClassValues = Content.EstimatedTestClassValues.ToArray(); 109 confidences = weightCalc.GetConfidence(solutions, 110 Content.ProblemData.TestIndices, 111 estimatedClassValues, 112 weightCalc.GetTestClassDelegate()).ToArray(); 117 113 } 118 114 … … 130 126 131 127 accuracy[i + 1] = accuracyCalc.Accuracy; 132 covered[i] = 1.0 - (double)notCovered / (double)rows; 128 if (rows > 0) { 129 covered[i] = 1.0 - (double)notCovered / (double)rows; 130 } 133 131 accuracyCalc.Reset(); 134 132 } … … 179 177 } 180 178 181 protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {182 if (samplesSelection == SamplesComboBoxAllSamples)183 return solutions;184 else if (samplesSelection == SamplesComboBoxTrainingSamples)185 return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));186 else if (samplesSelection == SamplesComboBoxTestSamples)187 return solutions.Where(s => s.ProblemData.IsTestSample(curRow));188 else189 return new List<IClassificationSolution>();190 }191 192 179 private double CalculateAreaUnderCurve(Series series) { 193 180 if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given."); -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs
r8811 r8814 112 112 113 113 double[] confidences = null; 114 if (SamplesComboBox.SelectedItem.ToString() == SamplesComboBoxAllSamples) { 115 confidences = weightCalc.GetConfidence(solutions, indizes, estimatedClassValues).ToArray(); 114 115 if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxAllSamples)) { 116 confidences = weightCalc.GetConfidence(solutions, 117 indizes, 118 estimatedClassValues, 119 weightCalc.GetAllClassDelegate()).ToArray(); 120 } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTrainingSamples)) { 121 confidences = weightCalc.GetConfidence(solutions, 122 indizes, 123 estimatedClassValues, 124 weightCalc.GetTrainingClassDelegate()).ToArray(); 125 } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTestSamples)) { 126 confidences = weightCalc.GetConfidence(solutions, 127 indizes, 128 estimatedClassValues, 129 weightCalc.GetTestClassDelegate()).ToArray(); 116 130 } 117 131 … … 125 139 correctClassified = target[i].IsAlmost(estimatedClassValues[i]); 126 140 values[i, 3] = correctClassified.ToString(); 127 if (SamplesComboBox.SelectedItem.ToString() == SamplesComboBoxAllSamples) { 128 curConfidence = confidences[i]; 129 } else { 130 curConfidence = weightCalc.GetConfidence(GetRelevantSolutions(SamplesComboBox.SelectedItem.ToString(), solutions, row), 131 indizes[i], estimatedClassValues[i]); 132 } 141 curConfidence = confidences[i]; 133 142 if (correctClassified) { 134 143 confidence[0] += curConfidence; … … 167 176 matrix.SortableView = true; 168 177 matrixView.Content = matrix; 169 }170 171 protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {172 if (samplesSelection == SamplesComboBoxAllSamples)173 return solutions;174 else if (samplesSelection == SamplesComboBoxTrainingSamples)175 return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));176 else if (samplesSelection == SamplesComboBoxTestSamples)177 return solutions.Where(s => s.ProblemData.IsTestSample(curRow));178 else179 return new List<IClassificationSolution>();180 }181 182 private IEnumerable<int> FindAllIndices(List<double?> list, double value) {183 List<int> indices = new List<int>();184 for (int i = 0; i < list.Count; i++) {185 if (list[i].Equals(value))186 indices.Add(i);187 }188 return indices;189 178 } 190 179 -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj
r8811 r8814 214 214 <Compile Include="FeatureCorrelation\AbstractFeatureCorrelationView.Designer.cs"> 215 215 <DependentUpon>AbstractFeatureCorrelationView.cs</DependentUpon> 216 </Compile>217 <Compile Include="Classification\ClassificationEnsembleSolutionConfidenceAccuracyDependence.cs">218 <SubType>UserControl</SubType>219 </Compile>220 <Compile Include="Classification\ClassificationEnsembleSolutionConfidenceAccuracyDependence.Designer.cs">221 <DependentUpon>ClassificationEnsembleSolutionConfidenceAccuracyDependence.cs</DependentUpon>222 216 </Compile> 223 217 <Compile Include="Classification\ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs"> -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj
r8811 r8814 94 94 <ItemGroup> 95 95 <Reference Include="ALGLIB-3.6.0, Version=3.6.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> 96 <HintPath> >..\..\..\..\trunk\sources\bin\ALGLIB-3.6.0.dll</HintPath>96 <HintPath>>..\..\..\..\trunk\sources\bin\ALGLIB-3.6.0.dll</HintPath> 97 97 <Private>False</Private> 98 98 </Reference> 99 99 <Reference Include="HeuristicLab.ALGLIB-3.6.0, Version=3.6.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL"> 100 <HintPath> >..\..\..\..\trunk\sources\bin\HeuristicLab.ALGLIB-3.6.0.dll</HintPath>100 <HintPath>>..\..\..\..\trunk\sources\bin\HeuristicLab.ALGLIB-3.6.0.dll</HintPath> 101 101 <Private>False</Private> 102 102 </Reference> -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs
r8534 r8814 57 57 } 58 58 59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue ) {59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 60 60 Dataset dataset = solutions.First().ProblemData.Dataset; 61 IList<double> values = solutions. Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();61 IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList(); 62 62 if (values.Count <= 0) 63 63 return double.NaN; … … 66 66 } 67 67 68 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {68 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 69 69 Dataset dataset = solutions.First().ProblemData.Dataset; 70 double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray(); 70 List<int> indicesList = indices.ToList(); 71 var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray()); 71 72 double[] confidences = new double[indices.Count()]; 72 73 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 73 74 74 for (int i = 0; i < indices.Count(); i++) { 75 double avg = values.Select(x => x[i]).Average(); 76 confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]); 75 for (int i = 0; i < indicesList.Count; i++) { 76 var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]); 77 if (values.Count() <= 0) { 78 confidences[i] = double.NaN; 79 } else { 80 double avg = values.Average(); 81 confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]); 82 } 77 83 } 78 84 … … 84 90 if (estimatedClassValue.Equals(classValues[i])) { 85 91 //special case: avgerage is higher than value of highest class 86 if (i == classValues.Length - 1 && avg > estimatedClassValue) {92 if (i == classValues.Length - 1 && avg >= estimatedClassValue) { 87 93 return 1; 88 94 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r8811 r8814 24 24 using HeuristicLab.Common; 25 25 using HeuristicLab.Core; 26 using HeuristicLab.Data;27 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 27 … … 112 111 } 113 112 114 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {113 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 115 114 if (solutions.Count() < 1) 116 115 return double.NaN; 117 116 Dataset dataset = solutions.First().ProblemData.Dataset; 118 117 var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() }) 119 .Where(a => a.Values.Equals(estimatedClassValue))118 .Where(a => handler(a.Solution.ProblemData, index) && a.Values.Equals(estimatedClassValue)) 120 119 .Select(a => a.Solution); 121 120 return (from sol in correctSolutions … … 123 122 } 124 123 125 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {124 public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 126 125 if (solutions.Count() < 1) 127 126 return Enumerable.Repeat(double.NaN, indices.Count()); 128 127 128 List<int> indicesList = indices.ToList(); 129 129 130 Dataset dataset = solutions.First().ProblemData.Dataset; 130 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices ).ToArray());131 Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray()); 131 132 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 132 double[] confidences = new double[indices .Count()];133 double[] confidences = new double[indicesList.Count]; 133 134 134 for (int i = 0; i < indices .Count(); i++) {135 for (int i = 0; i < indicesList.Count; i++) { 135 136 var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])); 136 137 confidences[i] = (from sol in correctSolutions 138 where handler(sol.Key.ProblemData, indicesList[i]) 137 139 select weights[sol.Key]).Sum(); 138 140 } … … 142 144 143 145 #region Helper 144 protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {145 return from i in indizes146 select targetValues[i];147 }148 146 protected bool PointInTraining(IClassificationProblemData problemData, int point) { 149 IntRange trainingPartition = problemData.TrainingPartition; 150 IntRange testPartition = problemData.TestPartition; 151 return (trainingPartition.Start <= point && point < trainingPartition.End) 152 && !(testPartition.Start <= point && point < testPartition.End); 147 return problemData.IsTrainingSample(point); 153 148 } 154 149 protected bool PointInTest(IClassificationProblemData problemData, int point) { 155 IntRange testPartition = problemData.TestPartition; 156 return testPartition.Start <= point && point < testPartition.End; 150 return problemData.IsTestSample(point); 157 151 } 158 152 protected bool AllPoints(IClassificationProblemData problemData, int point) { -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r8811 r8814 82 82 } 83 83 84 public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {84 public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 85 85 if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 86 86 return double.NaN; … … 88 88 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 89 89 90 return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue );90 return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue, handler); 91 91 } 92 92 93 protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue ) {94 return base.GetConfidence(solutions, index, estimatedClassValue );93 protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 94 return base.GetConfidence(solutions, index, estimatedClassValue, handler); 95 95 } 96 96 97 public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {97 public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 98 98 if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 99 99 return Enumerable.Repeat(double.NaN, indices.Count()); … … 101 101 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 102 102 103 return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue );103 return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue, handler); 104 104 } 105 105 106 public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {107 return base.GetConfidence(solutions, indices, estimatedClassValue );106 public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 107 return base.GetConfidence(solutions, indices, estimatedClassValue, handler); 108 108 } 109 109 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs
r8297 r8814 52 52 } 53 53 54 public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue ) {54 public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 55 55 if (solutions.Count() < 1) 56 56 return double.NaN; 57 Dataset dataset = solutions.First().ProblemData.Dataset; 58 int correctEstimated = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First()) 57 var votingSolutions = solutions.Where(s => handler(s.ProblemData, index)); 58 if (votingSolutions.Count() < 1) 59 return double.NaN; 60 Dataset dataset = votingSolutions.First().ProblemData.Dataset; 61 int correctEstimated = votingSolutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First()) 59 62 .Where(x => x.Equals(estimatedClassValue)) 60 63 .Count(); 61 return ((double)correctEstimated / (double) solutions.Count() - 0.5) * 2;64 return ((double)correctEstimated / (double)votingSolutions.Count() - 0.5) * 2; 62 65 } 63 66 64 public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {67 public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 65 68 if (solutions.Count() < 1) 66 69 return Enumerable.Repeat(double.NaN, indices.Count()); 70 71 List<int> indicesList = indices.ToList(); 67 72 Dataset dataset = solutions.First().ProblemData.Dataset; 68 var estimationsPerSolution = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, indices).ToArray()).ToArray();73 var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray()); 69 74 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 70 75 double correctEstimated; 71 76 double[] confidences = new double[indices.Count()]; 72 77 73 for (int i = 0; i < indices.Count(); i++) { 74 correctEstimated = estimationsPerSolution.Where(x => DoubleExtensions.IsAlmost(x[i], estimatedClassValueArr[i])).Count(); 75 confidences[i] = (correctEstimated / (double)solutions.Count() - 0.5) * 2; 78 for (int i = 0; i < indicesList.Count; i++) { 79 var votingSolutions = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])); 80 correctEstimated = votingSolutions.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])).Count(); 81 confidences[i] = (correctEstimated / (double)votingSolutions.Count() - 0.5) * 2; 76 82 } 77 83 -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r8534 r8814 57 57 } 58 58 59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) { 60 59 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) { 61 60 Dataset dataset = solutions.First().ProblemData.Dataset; 62 IList<double> values = solutions. Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();61 IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList(); 63 62 if (values.Count <= 0) 64 63 return double.NaN; … … 67 66 } 68 67 69 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue ) {68 public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) { 70 69 71 70 Dataset dataset = solutions.First().ProblemData.Dataset; 72 double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray(); 71 List<int> indicesList = indices.ToList(); 72 var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray()); 73 73 double[] confidences = new double[indices.Count()]; 74 74 double[] estimatedClassValueArr = estimatedClassValue.ToArray(); 75 75 76 for (int i = 0; i < indices.Count(); i++) { 77 double avg = values.Select(x => x[i]).Average(); 78 confidences[i] = GetMedianConfidence(avg, estimatedClassValueArr[i]); 76 for (int i = 0; i < indicesList.Count; i++) { 77 var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]).ToList(); 78 if (values.Count <= 0) { 79 confidences[i] = double.NaN; 80 } else { 81 double median = GetMedian(values); 82 confidences[i] = GetMedianConfidence(median, estimatedClassValueArr[i]); 83 } 79 84 } 80 85 … … 85 90 for (int i = 0; i < classValues.Length; i++) { 86 91 if (estimatedClassValue.Equals(classValues[i])) { 87 //special case: avgerageis higher than value of highest class88 if (i == classValues.Length - 1 && median > estimatedClassValue) {92 //special case: median is higher than value of highest class 93 if (i == classValues.Length - 1 && median >= estimatedClassValue) { 89 94 return 1; 90 95 } 91 //special case: averageis lower than value of lowest class96 //special case: median is lower than value of lowest class 92 97 if (i == 0 && median < estimatedClassValue) { 93 98 return 1; 94 99 } 95 //special case: averageis not between threshold of estimated class value100 //special case: median is not between threshold of estimated class value 96 101 if ((i < classValues.Length - 1 && median >= threshold[i + 1]) || median <= threshold[i]) { 97 102 return 0; 98 103 } 99 104 100 double thresholdToClassDistance, thresholdTo AverageValueDistance;105 double thresholdToClassDistance, thresholdToMedianValueDistance; 101 106 if (median >= classValues[i]) { 102 107 thresholdToClassDistance = threshold[i + 1] - classValues[i]; 103 thresholdTo AverageValueDistance = threshold[i + 1] - median;108 thresholdToMedianValueDistance = threshold[i + 1] - median; 104 109 } else { 105 110 thresholdToClassDistance = classValues[i] - threshold[i]; 106 thresholdTo AverageValueDistance = median - threshold[i];111 thresholdToMedianValueDistance = median - threshold[i]; 107 112 } 108 return (1 / thresholdToClassDistance) * thresholdTo AverageValueDistance;113 return (1 / thresholdToClassDistance) * thresholdToMedianValueDistance; 109 114 } 110 115 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs
r8811 r8814 29 29 void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions); 30 30 IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler); 31 double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue );32 IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue );31 double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler); 32 IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler); 33 33 34 34 CheckPoint GetTestClassDelegate();
Note: See TracChangeset
for help on using the changeset viewer.