Changeset 7562
- Timestamp:
- 03/06/12 15:08:13 (13 years ago)
- Location:
- branches/ClassificationEnsembleVoting
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs
r7531 r7562 28 28 using HeuristicLab.MainForm; 29 29 using HeuristicLab.MainForm.WindowsForms; 30 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification; 30 31 31 32 namespace HeuristicLab.Problems.DataAnalysis.Views { … … 96 97 } 97 98 99 IEnumerable<IClassificationSolution> solutions = Content.ClassificationSolutions.CheckedItems; 98 100 int classValuesCount = Content.ProblemData.ClassValues.Count; 99 int solutionsCount = Content.ClassificationSolutions.Count();101 int solutionsCount = solutions.Count(); 100 102 string[,] values = new string[indizes.Length, 5 + classValuesCount + solutionsCount]; 101 103 double[] target = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray(); 102 104 List<List<double?>> estimatedValuesVector = GetEstimatedValues(SamplesComboBox.SelectedItem.ToString(), indizes, 103 Content.ClassificationSolutions);104 //List<double> weights = Content.Weights.ToList(); 105 //double weightSum = weights.Sum();105 solutions); 106 107 IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator; 106 108 107 109 for (int i = 0; i < indizes.Length; i++) { … … 113 115 values[i, 2] = estimatedClassValues[i].ToString(); 114 116 values[i, 3] = (target[i].IsAlmost(estimatedClassValues[i])).ToString(); 115 116 //currently disabled for test purpose 117 118 //IEnumerable<int> indices = FindAllIndices(estimatedValuesVector[i], estimatedClassValues[i]); 119 //double confidence = 0.0; 120 //foreach (var index in indices) { 121 // confidence += weights[index]; 122 //} 123 //values[i, 4] = (confidence / weightSum).ToString(); 124 //var estimationCount = groups.Where(g => g.Key != null).Select(g => g.Count).Sum(); 125 //values[i, 4] = 126 // (((double)groups.Where(g => g.Key == estimatedClassValues[i]).Single().Count) / estimationCount).ToString(); 127 values[i, 4] = "1.0"; 117 values[i, 4] = weightCalc.GetConfidence(solutions, indizes[i], estimatedClassValues[i]).ToString(); 128 118 129 119 var groups = … … 145 135 List<string> columnNames = new List<string>() { "Id", TargetClassValuesColumnName, EstimatedClassValuesColumnName, CorrectClassificationColumnName, ConfidenceColumnName }; 146 136 columnNames.AddRange(Content.ProblemData.ClassNames); 147 columnNames.AddRange(Content. Model.Models.Select(m => m.Name));137 columnNames.AddRange(Content.ClassificationSolutions.CheckedItems.Select(s => s.Model.Name));//.Model.Models.Select(m => m.Name)); 148 138 matrix.ColumnNames = columnNames; 149 139 matrix.SortableView = true; -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7549 r7562 51 51 } 52 52 53 //[Storable]54 //private Dictionary<IClassificationModel, IntRange> trainingPartitions;55 //[Storable]56 //private Dictionary<IClassificationModel, IntRange> testPartitions;57 58 53 private IClassificationEnsembleSolutionWeightCalculator weightCalculator; 59 54 … … 66 61 } 67 62 } 63 get { return weightCalculator; } 68 64 } 69 65 … … 77 73 foreach (var model in Model.Models) { 78 74 IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone(); 79 //problemData.TrainingPartition.Start = trainingPartitions[model].Start;80 //problemData.TrainingPartition.End = trainingPartitions[model].End;81 //problemData.TestPartition.Start = testPartitions[model].Start;82 //problemData.TestPartition.End = testPartitions[model].End;83 84 75 classificationSolutions.Add(model.CreateClassificationSolution(problemData)); 85 76 } … … 89 80 private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 90 81 : base(original, cloner) { 91 //trainingPartitions = new Dictionary<IClassificationModel, IntRange>();92 //testPartitions = new Dictionary<IClassificationModel, IntRange>();93 //foreach (var pair in original.trainingPartitions) {94 // trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);95 //}96 //foreach (var pair in original.testPartitions) {97 // testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);98 //}99 100 82 classificationSolutions = cloner.Clone(original.classificationSolutions); 101 83 RegisterClassificationSolutionsEventHandler(); … … 104 86 public ClassificationEnsembleSolution() 105 87 : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) { 106 //trainingPartitions = new Dictionary<IClassificationModel, IntRange>();107 //testPartitions = new Dictionary<IClassificationModel, IntRange>();108 88 classificationSolutions = new CheckedItemCollection<IClassificationSolution>(); 109 89 … … 121 101 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 122 102 : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) { 123 //this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();124 //this.testPartitions = new Dictionary<IClassificationModel, IntRange>();125 103 this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>(); 126 104 … … 217 195 solution.ProblemData = problemData; 218 196 } 219 //foreach (var trainingPartition in trainingPartitions.Values) {220 // trainingPartition.Start = ProblemData.TrainingPartition.Start;221 // trainingPartition.End = ProblemData.TrainingPartition.End;222 //}223 //foreach (var testPartition in testPartitions.Values) {224 // testPartition.Start = ProblemData.TestPartition.Start;225 // testPartition.End = ProblemData.TestPartition.End;226 //}227 228 197 base.OnProblemDataChanged(); 229 198 } … … 256 225 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 257 226 Model.Add(solution.Model); 258 //trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;259 //testPartitions[solution.Model] = solution.ProblemData.TestPartition;260 227 } 261 228 … … 263 230 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 264 231 Model.Remove(solution.Model); 265 //trainingPartitions.Remove(solution.Model);266 //testPartitions.Remove(solution.Model);267 232 } 268 233 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r7559 r7562 113 113 } 114 114 115 public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) { 116 if (solutions.Count() < 1) 117 return double.NaN; 118 Dataset dataset = solutions.First().ProblemData.Dataset; 119 var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() }) 120 .Where(a => a.Values.Equals(estimatedClassValue)) 121 .Select(a => a.Solution); 122 return (from sol in correctSolutions 123 select weights[sol]).Sum(); 124 } 125 115 126 #region Helper 116 127 protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) { -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r7559 r7562 82 82 } 83 83 } 84 85 public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) { 86 if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 87 return double.NaN; 88 89 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 90 91 return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue); 92 } 93 94 protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) { 95 return base.GetConfidence(solutions, index, estimatedClassValue); 96 } 84 97 } 85 98 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r7549 r7562 20 20 #endregion 21 21 22 using System.Collections; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 84 85 } 85 86 AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold); 86 return median;87 return Enumerable.Repeat<double>(1, discriminantSolutions.Count()); 87 88 } 88 89 89 90 protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) { 90 double classValue = classValues.First();91 91 IList<double> values = estimatedValues.Select(x => x.Value).ToList(); 92 92 if (values.Count <= 0) 93 93 return double.NaN; 94 94 double median = GetMedian(values); 95 return GetClassValueToMedian(median); 96 } 97 private double GetClassValueToMedian(double median) { 98 double classValue = classValues.First(); 95 99 for (int i = 0; i < classValues.Count(); i++) { 96 100 if (median > threshold[i]) … … 100 104 } 101 105 return classValue; 106 } 107 108 protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) { 109 // only works with binary classification 110 if (!classValues.Count().Equals(2)) 111 return double.NaN; 112 Dataset dataset = solutions.First().ProblemData.Dataset; 113 IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList(); 114 if (values.Count <= 0) 115 return double.NaN; 116 double median = GetMedian(values); 117 if (estimatedClassValue.Equals(classValues[0])) { 118 if (median < estimatedClassValue) 119 return 1; 120 else if (median >= threshold[1]) 121 return 0; 122 else { 123 double distance = threshold[1] - classValues[0]; 124 return (1 / distance) * (median - classValues[0]); 125 } 126 } else if (estimatedClassValue.Equals(classValues[1])) { 127 if (median > estimatedClassValue) 128 return 1; 129 else if (median <= threshold[1]) 130 return 0; 131 else { 132 double distance = classValues[1] - threshold[1]; 133 return (1 / distance) * (classValues[1] - median); 134 } 135 } else 136 return double.NaN; 102 137 } 103 138 -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs
r7549 r7562 58 58 foreach (var solution in discriminantSolutions) { 59 59 estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 60 estimatedClassValues.Add(solution.Model.GetEstimated Values(dataSet, rows).ToList());60 estimatedClassValues.Add(solution.Model.GetEstimatedClassValues(dataSet, rows).ToList()); 61 61 } 62 62 -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs
r7549 r7562 25 25 namespace HeuristicLab.Problems.DataAnalysis.Interfaces.Classification { 26 26 public delegate bool CheckPoint(IClassificationProblemData problemData, int point); 27 27 28 public interface IClassificationEnsembleSolutionWeightCalculator : INamedItem { 28 29 void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions); 29 30 IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler); 31 double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue); 32 30 33 CheckPoint GetTestClassDelegate(); 31 34 CheckPoint GetTrainingClassDelegate();
Note: See TracChangeset
for help on using the changeset viewer.