Changeset 7549 for branches/ClassificationEnsembleVoting
- Timestamp:
- 03/05/12 17:02:37 (13 years ago)
- Location:
- branches/ClassificationEnsembleVoting
- Files:
-
- 13 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionModelView.cs
r7259 r7549 58 58 } 59 59 60 private class ModelsView : ItemCollectionView<IClassificationSolution> {60 private class ModelsView : CheckedItemCollectionView<IClassificationSolution> { 61 61 protected override void SetEnabledStateOfControls() { 62 62 base.SetEnabledStateOfControls(); -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Solution Views/ClassificationEnsembleSolutionView.cs
r7464 r7549 58 58 } 59 59 60 pr ivatevoid cmbWeightCalculator_SelectedIndexChanged(object sender, System.EventArgs e) {60 protected void cmbWeightCalculator_SelectedIndexChanged(object sender, System.EventArgs e) { 61 61 if (cmbWeightCalculator.SelectedItem != null) 62 62 Content.WeightCalculator = (IClassificationEnsembleSolutionWeightCalculator)cmbWeightCalculator.SelectedItem; -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r7531 r7549 46 46 } 47 47 48 private readonly ItemCollection<IClassificationSolution> classificationSolutions;49 public I ItemCollection<IClassificationSolution> ClassificationSolutions {48 private readonly CheckedItemCollection<IClassificationSolution> classificationSolutions; 49 public ICheckedItemCollection<IClassificationSolution> ClassificationSolutions { 50 50 get { return classificationSolutions; } 51 51 } 52 52 53 [Storable]54 private Dictionary<IClassificationModel, IntRange> trainingPartitions;55 [Storable]56 private Dictionary<IClassificationModel, IntRange> testPartitions;53 //[Storable] 54 //private Dictionary<IClassificationModel, IntRange> trainingPartitions; 55 //[Storable] 56 //private Dictionary<IClassificationModel, IntRange> testPartitions; 57 57 58 58 private IClassificationEnsembleSolutionWeightCalculator weightCalculator; … … 62 62 if (value != null) { 63 63 weightCalculator = value; 64 weightCalculator.CalculateNormalizedWeights(classificationSolutions);65 64 if (!ProblemData.IsEmpty) 66 65 RecalculateResults(); … … 72 71 private ClassificationEnsembleSolution(bool deserializing) 73 72 : base(deserializing) { 74 classificationSolutions = new ItemCollection<IClassificationSolution>();73 classificationSolutions = new CheckedItemCollection<IClassificationSolution>(); 75 74 } 76 75 [StorableHook(HookType.AfterDeserialization)] … … 78 77 foreach (var model in Model.Models) { 79 78 IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone(); 80 problemData.TrainingPartition.Start = trainingPartitions[model].Start;81 problemData.TrainingPartition.End = trainingPartitions[model].End;82 problemData.TestPartition.Start = testPartitions[model].Start;83 problemData.TestPartition.End = testPartitions[model].End;79 //problemData.TrainingPartition.Start = trainingPartitions[model].Start; 80 //problemData.TrainingPartition.End = trainingPartitions[model].End; 81 //problemData.TestPartition.Start = testPartitions[model].Start; 82 //problemData.TestPartition.End = testPartitions[model].End; 84 83 85 84 classificationSolutions.Add(model.CreateClassificationSolution(problemData)); … … 90 89 private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 91 90 : base(original, cloner) { 92 trainingPartitions = new Dictionary<IClassificationModel, IntRange>();93 testPartitions = new Dictionary<IClassificationModel, IntRange>();94 foreach (var pair in original.trainingPartitions) {95 trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);96 }97 foreach (var pair in original.testPartitions) {98 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);99 }91 //trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 92 //testPartitions = new Dictionary<IClassificationModel, IntRange>(); 93 //foreach (var pair in original.trainingPartitions) { 94 // trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 95 //} 96 //foreach (var pair in original.testPartitions) { 97 // testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 98 //} 100 99 101 100 classificationSolutions = cloner.Clone(original.classificationSolutions); … … 105 104 public ClassificationEnsembleSolution() 106 105 : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) { 107 trainingPartitions = new Dictionary<IClassificationModel, IntRange>();108 testPartitions = new Dictionary<IClassificationModel, IntRange>();109 classificationSolutions = new ItemCollection<IClassificationSolution>();106 //trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 107 //testPartitions = new Dictionary<IClassificationModel, IntRange>(); 108 classificationSolutions = new CheckedItemCollection<IClassificationSolution>(); 110 109 111 110 weightCalculator = new MajorityVoteWeightCalculator(); … … 122 121 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 123 122 : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) { 124 this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();125 this.testPartitions = new Dictionary<IClassificationModel, IntRange>();126 this.classificationSolutions = new ItemCollection<IClassificationSolution>();123 //this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 124 //this.testPartitions = new Dictionary<IClassificationModel, IntRange>(); 125 this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>(); 127 126 128 127 List<IClassificationSolution> solutions = new List<IClassificationSolution>(); … … 155 154 classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved); 156 155 classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset); 156 classificationSolutions.CheckedItemsChanged += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CheckedItemsChanged); 157 157 } 158 158 159 159 protected override void RecalculateResults() { 160 weightCalculator.CalculateNormalizedWeights(classificationSolutions.CheckedItems); 160 161 CalculateResults(); 161 162 } … … 163 164 #region Evaluation 164 165 public override IEnumerable<double> EstimatedTrainingClassValues { 165 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TrainingIndizes); } 166 get { 167 return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems, 168 ProblemData.Dataset, 169 ProblemData.TrainingIndizes, 170 weightCalculator.GetTrainingClassDelegate()); 171 } 166 172 } 167 173 168 174 public override IEnumerable<double> EstimatedTestClassValues { 169 get { return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, ProblemData.TestIndizes); } 170 } 171 172 private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) { 173 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 174 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 175 } 176 177 private bool RowIsTestForModel(int currentRow, IClassificationModel model) { 178 return testPartitions == null || !testPartitions.ContainsKey(model) || 179 (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End); 175 get { 176 return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems, 177 ProblemData.Dataset, 178 ProblemData.TestIndizes, 179 weightCalculator.GetTestClassDelegate()); 180 } 180 181 } 181 182 182 183 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 183 return weightCalculator.AggregateEstimatedClassValues(Model.Models, ProblemData.Dataset, rows); 184 return weightCalculator.AggregateEstimatedClassValues(classificationSolutions.CheckedItems, 185 ProblemData.Dataset, 186 rows, 187 weightCalculator.GetAllClassDelegate()); 184 188 } 185 189 186 190 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { 187 if (!Model.Models.Any()) yield break; 188 var estimatedValuesEnumerators = (from model in Model.Models 191 IEnumerable<IClassificationModel> models = classificationSolutions.CheckedItems.Select(sol => sol.Model); 192 if (!models.Any()) yield break; 193 var estimatedValuesEnumerators = (from model in models 189 194 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator()) 190 195 .ToList(); … … 212 217 solution.ProblemData = problemData; 213 218 } 214 foreach (var trainingPartition in trainingPartitions.Values) {215 trainingPartition.Start = ProblemData.TrainingPartition.Start;216 trainingPartition.End = ProblemData.TrainingPartition.End;217 }218 foreach (var testPartition in testPartitions.Values) {219 testPartition.Start = ProblemData.TestPartition.Start;220 testPartition.End = ProblemData.TestPartition.End;221 }219 //foreach (var trainingPartition in trainingPartitions.Values) { 220 // trainingPartition.Start = ProblemData.TrainingPartition.Start; 221 // trainingPartition.End = ProblemData.TrainingPartition.End; 222 //} 223 //foreach (var testPartition in testPartitions.Values) { 224 // testPartition.Start = ProblemData.TestPartition.Start; 225 // testPartition.End = ProblemData.TestPartition.End; 226 //} 222 227 223 228 base.OnProblemDataChanged(); … … 244 249 RecalculateResults(); 245 250 } 251 private void classificationSolutions_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 252 RecalculateResults(); 253 } 246 254 247 255 private void AddClassificationSolution(IClassificationSolution solution) { 248 256 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 249 257 Model.Add(solution.Model); 250 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 251 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 252 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 258 //trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 259 //testPartitions[solution.Model] = solution.ProblemData.TestPartition; 253 260 } 254 261 … … 256 263 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 257 264 Model.Remove(solution.Model); 258 trainingPartitions.Remove(solution.Model); 259 testPartitions.Remove(solution.Model); 260 weightCalculator.CalculateNormalizedWeights(classificationSolutions); 265 //trainingPartitions.Remove(solution.Model); 266 //testPartitions.Remove(solution.Model); 261 267 } 262 268 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AccuracyWeightCalculator.cs
r7531 r7549 46 46 } 47 47 48 protected override IEnumerable<double> CalculateWeights(I temCollection<IClassificationSolution> classificationSolutions) {48 protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) { 49 49 return classificationSolutions.Select(s => s.TrainingAccuracy); 50 50 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs
r7531 r7549 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 25 using HeuristicLab.Core; 26 using HeuristicLab.Data; 27 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 28 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification; … … 45 45 } 46 46 47 private I Enumerable<double> weights;47 private IDictionary<IClassificationSolution, double> weights; 48 48 49 49 /// <summary> … … 52 52 /// <param name="classificationSolutions"></param> 53 53 /// <returns>weights which are equal or bigger than zero</returns> 54 public void CalculateNormalizedWeights(I temCollection<IClassificationSolution> classificationSolutions) {54 public void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions) { 55 55 List<double> weights = new List<double>(); 56 if (classificationSolutions.Count > 0) {56 if (classificationSolutions.Count() > 0) { 57 57 foreach (var weight in CalculateWeights(classificationSolutions)) { 58 58 weights.Add(weight >= 0 ? weight : 0); 59 59 } 60 60 } 61 this.weights = weights.Select(x => x / weights.Sum()); 61 double sum = weights.Sum(); 62 this.weights = classificationSolutions.Zip(weights, (sol, wei) => new { sol, wei }).ToDictionary(x => x.sol, x => x.wei / sum); 62 63 } 63 64 64 protected abstract IEnumerable<double> CalculateWeights(I temCollection<IClassificationSolution> classificationSolutions);65 protected abstract IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions); 65 66 66 public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) { 67 return from xs in ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows) 67 #region delegate CheckPoint 68 public CheckPoint GetTestClassDelegate() { 69 return PointInTest; 70 } 71 public CheckPoint GetTrainingClassDelegate() { 72 return PointInTraining; 73 } 74 public CheckPoint GetAllClassDelegate() { 75 return AllPoints; 76 } 77 #endregion 78 79 public virtual IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 80 return from xs in GetEstimatedClassValues(solutions, dataset, rows, handler) 68 81 select AggregateEstimatedClassValues(xs); 69 82 } 70 83 71 protected double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 72 if (!estimatedClassValues.Count().Equals(weights.Count())) 73 throw new ArgumentException("'estimatedClassValues' has " + estimatedClassValues.Count() + " elements, while 'weights' has" + weights.Count()); 84 protected double AggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues) { 74 85 IDictionary<double, double> weightSum = new Dictionary<double, double>(); 75 for (int i = 0; i < estimatedClassValues.Count(); i++) {76 if (!weightSum.ContainsKey( estimatedClassValues.ElementAt(i)))77 weightSum[ estimatedClassValues.ElementAt(i)] = 0.0;78 weightSum[ estimatedClassValues.ElementAt(i)] += weights.ElementAt(i);86 foreach (var item in estimatedClassValues) { 87 if (!weightSum.ContainsKey(item.Value)) 88 weightSum[item.Value] = 0.0; 89 weightSum[item.Value] += weights[item.Key]; 79 90 } 80 91 if (weightSum.Count <= 0) … … 88 99 } 89 100 90 protected static IEnumerable<IEnumerable<double>> GetEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {91 if (! models.Any()) yield break;92 var estimatedValuesEnumerators = (from model in models93 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator())101 protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 102 if (!solutions.Any()) yield break; 103 var estimatedValuesEnumerators = (from solution in solutions 104 select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() }) 94 105 .ToList(); 95 106 96 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 97 yield return from enumerator in estimatedValuesEnumerators 98 select enumerator.Current; 107 var rowEnumerator = rows.GetEnumerator(); 108 while (rowEnumerator.MoveNext() & estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) { 109 yield return (from enumerator in estimatedValuesEnumerators 110 where handler(enumerator.Solution.ProblemData, rowEnumerator.Current) 111 select enumerator) 112 .ToDictionary(x => x.Solution, x => x.EstimatedValuesEnumerator.Current); 99 113 } 100 114 } … … 105 119 select targetValues[i]; 106 120 } 121 protected bool PointInTraining(IClassificationProblemData problemData, int point) { 122 IntRange trainingPartition = problemData.TrainingPartition; 123 IntRange testPartition = problemData.TestPartition; 124 return (trainingPartition.Start <= point && point < trainingPartition.End) 125 && !(testPartition.Start <= point && point < testPartition.End); 126 } 127 protected bool PointInTest(IClassificationProblemData problemData, int point) { 128 IntRange testPartition = problemData.TestPartition; 129 return testPartition.Start <= point && point < testPartition.End; 130 } 131 protected bool AllPoints(IClassificationProblemData problemData, int point) { 132 return true; 133 } 107 134 #endregion 108 135 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ContinuousPointCertaintyWeightCalculator.cs
r7531 r7549 49 49 } 50 50 51 protected override IEnumerable<double> DiscriminantCalculateWeights(I temCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {51 protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 52 52 List<double> weights = new List<double>(); 53 53 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 54 IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes);54 IEnumerable<double> targetValues; 55 55 IEnumerator<double> trainingValues; 56 56 57 57 //only works for binary classification 58 58 if (!problemData.ClassValues.Count().Equals(2)) 59 return Enumerable.Repeat<double>(1, discriminantSolutions.Count );59 return Enumerable.Repeat<double>(1, discriminantSolutions.Count()); 60 60 61 61 double maxClass = problemData.ClassValues.Max(); … … 64 64 65 65 foreach (var solution in discriminantSolutions) { 66 problemData = solution.ProblemData; 67 targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes); 68 trainingValues = targetValues.GetEnumerator(); 69 66 70 IEnumerator<double> estimatedTrainingVal = solution.EstimatedTrainingValues.GetEnumerator(); 67 71 IEnumerator<double> estimatedTrainingClassVal = solution.EstimatedTrainingClassValues.GetEnumerator(); 68 72 69 trainingValues = targetValues.GetEnumerator();70 73 double curWeight = 0.0; 71 74 while (estimatedTrainingVal.MoveNext() && estimatedTrainingClassVal.MoveNext() && trainingValues.MoveNext()) { 72 //if (estimatedTrainingClassVal.Current.Equals(trainingValues.Current)) {73 75 if (trainingValues.Current.Equals(maxClass)) { 74 76 if (estimatedTrainingVal.Current >= maxClass) … … 86 88 } 87 89 } 88 //}89 90 } 90 weights.Add(curWeight); 91 // normalize the weight (otherwise a model with a bigger training partition would probably be better) 92 weights.Add(curWeight / targetValues.Count()); 91 93 } 92 94 return weights; -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs
r7531 r7549 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Core;26 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 26 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification; 27 27 28 28 namespace HeuristicLab.Problems.DataAnalysis { … … 41 41 } 42 42 43 protected override IEnumerable<double> CalculateWeights(I temCollection<IClassificationSolution> classificationSolutions) {43 protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) { 44 44 if (!classificationSolutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 45 return Enumerable.Repeat<double>(1.0, classificationSolutions.Count );45 return Enumerable.Repeat<double>(1.0, classificationSolutions.Count()); 46 46 47 ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions = new ItemCollection<IDiscriminantFunctionClassificationSolution>(); 48 foreach (var solution in classificationSolutions) { 49 discriminantSolutions.Add((IDiscriminantFunctionClassificationSolution)solution); 50 } 47 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = classificationSolutions.Cast<IDiscriminantFunctionClassificationSolution>(); 51 48 52 49 return DiscriminantCalculateWeights(discriminantSolutions); 53 50 } 54 51 55 protected abstract IEnumerable<double> DiscriminantCalculateWeights(I temCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions);52 protected abstract IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions); 56 53 57 public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassification Model> models, Dataset dataset, IEnumerable<int> rows) {58 if (! models.All(x => x is IDiscriminantFunctionClassificationModel))54 public override IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 55 if (!solutions.All(x => x is IDiscriminantFunctionClassificationSolution)) 59 56 return Enumerable.Repeat<double>(0.0, rows.Count()); 60 57 61 IEnumerable<IDiscriminantFunctionClassification Model> discriminantModels = models.Cast<IDiscriminantFunctionClassificationModel>();58 IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>(); 62 59 63 IEnumerable<I Enumerable<double>> estimatedClassValues = ClassificationWeightCalculator.GetEstimatedClassValues(models, dataset, rows);64 IEnumerable<I Enumerable<double>> estimatedValues = DiscriminantClassificationWeightCalculator.GetEstimatedValues(discriminantModels, dataset, rows);60 IEnumerable<IDictionary<IClassificationSolution, double>> estimatedClassValues = GetEstimatedClassValues(solutions, dataset, rows, handler); 61 IEnumerable<IDictionary<IClassificationSolution, double>> estimatedValues = GetEstimatedValues(discriminantSolutions, dataset, rows, handler); 65 62 66 63 return from zip in estimatedClassValues.Zip(estimatedValues, (classValues, values) => new { ClassValues = classValues, Values = values }) … … 68 65 } 69 66 70 protected virtual double DiscriminantAggregateEstimatedClassValues(I Enumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {67 protected virtual double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) { 71 68 return AggregateEstimatedClassValues(estimatedClassValues); 72 69 } 73 70 74 protected static IEnumerable<IEnumerable<double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationModel> models, Dataset dataset, IEnumerable<int> rows) {75 if (! models.Any()) yield break;76 var estimatedValuesEnumerators = (from model in models77 select model.GetEstimatedValues(dataset, rows).GetEnumerator())78 .ToList();71 protected IEnumerable<IDictionary<IClassificationSolution, double>> GetEstimatedValues(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler) { 72 if (!solutions.Any()) yield break; 73 var estimatedValuesEnumerators = (from solution in solutions 74 select new { Solution = solution, EstimatedValuesEnumerator = solution.Model.GetEstimatedClassValues(dataset, rows).GetEnumerator() }) 75 .ToList(); 79 76 80 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 81 yield return from enumerator in estimatedValuesEnumerators 82 select enumerator.Current; 77 var rowEnumerator = rows.GetEnumerator(); 78 while (rowEnumerator.MoveNext() && estimatedValuesEnumerators.All(x => x.EstimatedValuesEnumerator.MoveNext())) { 79 yield return (from enumerator in estimatedValuesEnumerators 80 where handler(enumerator.Solution.ProblemData, rowEnumerator.Current) 81 select enumerator) 82 .ToDictionary(x => (IClassificationSolution)x.Solution, x => x.EstimatedValuesEnumerator.Current); 83 83 } 84 84 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs
r7531 r7549 48 48 } 49 49 50 protected override IEnumerable<double> CalculateWeights(I temCollection<IClassificationSolution> classificationSolutions) {51 return Enumerable.Repeat<double>(1, classificationSolutions.Count );50 protected override IEnumerable<double> CalculateWeights(IEnumerable<IClassificationSolution> classificationSolutions) { 51 return Enumerable.Repeat<double>(1, classificationSolutions.Count()); 52 52 } 53 53 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs
r7531 r7549 46 46 protected double[] classValues; 47 47 48 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 49 List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>(); 50 List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>(); 48 /// <summary> 49 /// 50 /// </summary> 51 /// <param name="discriminantSolutions"></param> 52 /// <returns>median instead of weights, because it doesn't use any weights</returns> 53 protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 54 List<List<double>> estimatedValues = new List<List<double>>(); 55 List<List<double>> estimatedClassValues = new List<List<double>>(); 56 57 List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList(); 58 Dataset dataSet = solutionProblemData[0].Dataset; 59 IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows); 51 60 foreach (var solution in discriminantSolutions) { 52 estimated TrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList());53 estimated TrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList());61 estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 62 estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 54 63 } 55 64 56 65 List<double> median = new List<double>(); 57 58 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 59 List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(); 60 IEnumerable<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes); 61 62 for (int i = 0; i < estimatedTrainingClassValEnumerators.First().Count; i++) { 63 var points = (from solution in estimatedTrainingValEnumerators 64 select solution[i]) 65 .OrderBy(p => p) 66 .ToList(); 67 68 median.Add(GetMedian(points)); 66 List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList(); 67 IList<double> curTrainingpoints = new List<double>(); 68 int removed = 0; 69 int count = targetValues.Count; 70 for (int point = 0; point < count; point++) { 71 curTrainingpoints.Clear(); 72 for (int solutionPos = 0; solutionPos < solutionProblemData.Count; solutionPos++) { 73 if (PointInTraining(solutionProblemData[solutionPos], point)) { 74 curTrainingpoints.Add(estimatedValues[solutionPos][point]); 75 } 76 } 77 if (curTrainingpoints.Count > 0) 78 median.Add(GetMedian(curTrainingpoints.OrderBy(p => p).ToList())); 79 else { 80 //remove not used points 81 targetValues.RemoveAt(point - removed); 82 removed++; 83 } 69 84 } 70 AccuracyMaximizationThresholdCalculator.CalculateThresholds( problemData, median, trainingVal, out classValues, out threshold);85 AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold); 71 86 return median; 72 87 } 73 88 74 protected override double DiscriminantAggregateEstimatedClassValues(I Enumerable<double> estimatedClassValues, IEnumerable<double> estimatedValues) {89 protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) { 75 90 double classValue = classValues.First(); 76 double median = GetMedian(estimatedValues.ToList()); 91 IList<double> values = estimatedValues.Select(x => x.Value).ToList(); 92 if (values.Count <= 0) 93 return double.NaN; 94 double median = GetMedian(values); 77 95 for (int i = 0; i < classValues.Count(); i++) { 78 96 if (median > threshold[i]) … … 87 105 int count = estimatedValues.Count; 88 106 if (count % 2 == 0) 89 return 0.5 * (estimatedValues[count / 2 ] + estimatedValues[count / 2 + 1]);107 return 0.5 * (estimatedValues[count / 2 - 1] + estimatedValues[count / 2]); 90 108 else 91 return estimatedValues[ (count + 1)/ 2];109 return estimatedValues[count / 2]; 92 110 } 93 111 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs
r7531 r7549 49 49 } 50 50 51 protected override IEnumerable<double> DiscriminantCalculateWeights(ItemCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 52 List<List<double>> estimatedTrainingValEnumerators = new List<List<double>>(); 53 List<List<double>> estimatedTrainingClassValEnumerators = new List<List<double>>(); 51 protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 52 List<List<double>> estimatedValues = new List<List<double>>(); 53 List<List<double>> estimatedClassValues = new List<List<double>>(); 54 55 List<IClassificationProblemData> solutionProblemData = discriminantSolutions.Select(sol => sol.ProblemData).ToList(); 56 Dataset dataSet = solutionProblemData[0].Dataset; 57 IEnumerable<int> rows = Enumerable.Range(0, dataSet.Rows); 54 58 foreach (var solution in discriminantSolutions) { 55 estimated TrainingValEnumerators.Add(solution.EstimatedTrainingValues.ToList());56 estimated TrainingClassValEnumerators.Add(solution.EstimatedTrainingClassValues.ToList());59 estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 60 estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList()); 57 61 } 58 62 59 List<double> weights = Enumerable.Repeat<double>(0, discriminantSolutions.Count()).ToList<double>(); 60 61 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 62 List<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(); 63 List<double> trainingVal = GetValues(targetValues, problemData.TrainingIndizes).ToList(); 63 List<double> weights = Enumerable.Repeat<double>(0, solutionProblemData.Count).ToList<double>(); 64 List<double> targetValues = dataSet.GetDoubleValues(solutionProblemData[0].TargetVariable).ToList(); 64 65 65 66 double pointAvg, help; 66 67 int count; 67 for (int point = 0; point < estimatedTrainingClassValEnumerators.First().Count; point++) {68 for (int point = 0; point < targetValues.Count; point++) { 68 69 pointAvg = 0.0; 69 70 count = 0; 70 for (int solution = 0; solution < estimatedTrainingClassValEnumerators.Count; solution++) { 71 if (estimatedTrainingClassValEnumerators[solution][point].Equals(targetValues[point])) { 72 pointAvg += estimatedTrainingValEnumerators[solution][point]; 71 for (int solutionPos = 0; solutionPos < estimatedClassValues.Count; solutionPos++) { 72 if (PointInTraining(solutionProblemData[solutionPos], point) 73 && estimatedClassValues[solutionPos][point].Equals(targetValues[point])) { 74 pointAvg += estimatedValues[solutionPos][point]; 73 75 count++; 74 76 } 75 77 } 76 78 pointAvg /= (double)count; 77 for (int solution = 0; solution < estimatedTrainingClassValEnumerators.Count; solution++) { 78 if (estimatedTrainingClassValEnumerators[solution][point].Equals(targetValues[point])) { 79 weights[solution] += 0.5; 80 help = Math.Abs(estimatedTrainingValEnumerators[solution][point] - 0.5); 81 weights[solution] += help < 0.5 ? 0.5 - help : 0.0; 79 for (int solutionPos = 0; solutionPos < estimatedClassValues.Count; solutionPos++) { 80 if (PointInTraining(solutionProblemData[solutionPos], point) 81 && estimatedClassValues[solutionPos][point].Equals(targetValues[point])) { 82 weights[solutionPos] += 0.5; 83 help = Math.Abs(estimatedValues[solutionPos][point] - 0.5); 84 weights[solutionPos] += help < 0.5 ? 0.5 - help : 0.0; 82 85 } 83 86 } 87 } 88 // normalize the weight (otherwise a model with a bigger training partition would probably be better) 89 for (int i = 0; i < weights.Count; i++) { 90 weights[i] = weights[i] / solutionProblemData[i].TrainingIndizes.Count(); 84 91 } 85 92 return weights; -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/PointCertaintyWeightCalculator.cs
r7531 r7549 44 44 } 45 45 46 protected override IEnumerable<double> DiscriminantCalculateWeights(I temCollection<IDiscriminantFunctionClassificationSolution> discriminantSolutions) {46 protected override IEnumerable<double> DiscriminantCalculateWeights(IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions) { 47 47 List<double> weights = new List<double>(); 48 48 IClassificationProblemData problemData = discriminantSolutions.ElementAt(0).ProblemData; 49 IEnumerable<double> targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes); 50 IEnumerator<double> trainingValues; 49 // class Values are the same in all problem data sets 51 50 double avg = problemData.ClassValues.Average(); 52 51 52 IEnumerable<double> targetValues; 53 IEnumerator<double> trainingValues; 54 53 55 foreach (var solution in discriminantSolutions) { 56 problemData = solution.ProblemData; 57 targetValues = GetValues(problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToList(), problemData.TrainingIndizes); 58 trainingValues = targetValues.GetEnumerator(); 59 54 60 IEnumerator<double> estimatedTrainingVal = solution.EstimatedTrainingValues.GetEnumerator(); 55 61 IEnumerator<double> estimatedTrainingClassVal = solution.EstimatedTrainingClassValues.GetEnumerator(); 56 62 57 trainingValues = targetValues.GetEnumerator();58 63 double curWeight = 0.0; 59 64 while (estimatedTrainingVal.MoveNext() && estimatedTrainingClassVal.MoveNext() && trainingValues.MoveNext()) { … … 67 72 } 68 73 } 69 weights.Add(curWeight); 74 // normalize the weight (otherwise a model with a bigger training partition would probably be better) 75 weights.Add(curWeight / targetValues.Count()); 70 76 } 71 77 return weights; -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolution.cs
r7259 r7549 25 25 public interface IClassificationEnsembleSolution : IClassificationSolution { 26 26 new IClassificationEnsembleModel Model { get; } 27 I ItemCollection<IClassificationSolution> ClassificationSolutions { get; }27 ICheckedItemCollection<IClassificationSolution> ClassificationSolutions { get; } 28 28 IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows); 29 29 } -
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs
r7531 r7549 24 24 25 25 namespace HeuristicLab.Problems.DataAnalysis.Interfaces.Classification { 26 public delegate bool CheckPoint(IClassificationProblemData problemData, int point); 26 27 public interface IClassificationEnsembleSolutionWeightCalculator : INamedItem { 27 void CalculateNormalizedWeights(ItemCollection<IClassificationSolution> classificationSolutions); 28 IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationModel> models, Dataset dataset, IEnumerable<int> rows); 28 void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions); 29 IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler); 30 CheckPoint GetTestClassDelegate(); 31 CheckPoint GetTrainingClassDelegate(); 32 CheckPoint GetAllClassDelegate(); 29 33 } 30 34 }
Note: See TracChangeset
for help on using the changeset viewer.