Changeset 6760 for branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
- Timestamp:
- 09/14/11 13:59:25 (13 years ago)
- Location:
- branches/PersistenceSpeedUp
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/PersistenceSpeedUp
- Property svn:ignore
-
old new 12 12 *.psess 13 13 *.vsp 14 *.docstates
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 30 … … 32 35 [StorableClass] 33 36 [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")] 34 // [Creatable("Data Analysis")] 35 public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution { 37 [Creatable("Data Analysis - Ensembles")] 38 public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 39 public new IClassificationEnsembleModel Model { 40 get { return (IClassificationEnsembleModel)base.Model; } 41 } 42 public new ClassificationEnsembleProblemData ProblemData { 43 get { return (ClassificationEnsembleProblemData)base.ProblemData; } 44 set { base.ProblemData = value; } 45 } 46 47 private readonly ItemCollection<IClassificationSolution> classificationSolutions; 48 public IItemCollection<IClassificationSolution> ClassificationSolutions { 49 get { return classificationSolutions; } 50 } 36 51 37 52 [Storable] 38 private List<IClassificationModel> models;39 public IEnumerable<IClassificationModel> Models {40 get { return new List<IClassificationModel>(models); }41 } 53 private Dictionary<IClassificationModel, IntRange> trainingPartitions; 54 [Storable] 55 private Dictionary<IClassificationModel, IntRange> testPartitions; 56 42 57 [StorableConstructor] 43 protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { } 44 protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 58 private ClassificationEnsembleSolution(bool deserializing) 59 : base(deserializing) { 60 classificationSolutions = new ItemCollection<IClassificationSolution>(); 61 } 62 [StorableHook(HookType.AfterDeserialization)] 63 private void AfterDeserialization() { 64 foreach (var model in Model.Models) { 65 IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone(); 66 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 67 problemData.TrainingPartition.End = trainingPartitions[model].End; 68 problemData.TestPartition.Start = testPartitions[model].Start; 69 problemData.TestPartition.End = testPartitions[model].End; 70 71 classificationSolutions.Add(model.CreateClassificationSolution(problemData)); 72 } 73 RegisterClassificationSolutionsEventHandler(); 74 } 75 76 private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 45 77 : base(original, cloner) { 46 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 47 } 48 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models) 49 : base() { 50 this.name = ItemName; 51 this.description = ItemDescription; 52 this.models = new List<IClassificationModel>(models); 78 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 79 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 80 foreach (var pair in original.trainingPartitions) { 81 trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 82 } 83 foreach (var pair in original.testPartitions) { 84 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 85 } 86 87 classificationSolutions = cloner.Clone(original.classificationSolutions); 88 RegisterClassificationSolutionsEventHandler(); 89 } 90 91 public ClassificationEnsembleSolution() 92 : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) { 93 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 94 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 95 classificationSolutions = new ItemCollection<IClassificationSolution>(); 96 97 RegisterClassificationSolutionsEventHandler(); 98 } 99 100 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData) 101 : this(models, problemData, 102 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 103 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 104 ) { } 105 106 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 107 : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) { 108 this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 109 this.testPartitions = new Dictionary<IClassificationModel, IntRange>(); 110 this.classificationSolutions = new ItemCollection<IClassificationSolution>(); 111 112 List<IClassificationSolution> solutions = new List<IClassificationSolution>(); 113 var modelEnumerator = models.GetEnumerator(); 114 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 115 var testPartitionEnumerator = testPartitions.GetEnumerator(); 116 117 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 118 var p = (IClassificationProblemData)problemData.Clone(); 119 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 120 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 121 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 122 p.TestPartition.End = testPartitionEnumerator.Current.End; 123 124 solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p)); 125 } 126 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 127 throw new ArgumentException(); 128 } 129 130 RegisterClassificationSolutionsEventHandler(); 131 classificationSolutions.AddRange(solutions); 53 132 } 54 133 … … 56 135 return new ClassificationEnsembleSolution(this, cloner); 57 136 } 58 59 #region IClassificationEnsembleModel Members 137 private void RegisterClassificationSolutionsEventHandler() { 138 classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded); 139 classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved); 140 classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset); 141 } 142 143 protected override void RecalculateResults() { 144 CalculateResults(); 145 } 146 147 #region Evaluation 148 public override IEnumerable<double> EstimatedTrainingClassValues { 149 get { 150 var rows = ProblemData.TrainingIndizes; 151 var estimatedValuesEnumerators = (from model in Model.Models 152 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 153 .ToList(); 154 var rowsEnumerator = rows.GetEnumerator(); 155 // aggregate to make sure that MoveNext is called for all enumerators 156 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 157 int currentRow = rowsEnumerator.Current; 158 159 var selectedEnumerators = from pair in estimatedValuesEnumerators 160 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 161 select pair.EstimatedValuesEnumerator; 162 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 163 } 164 } 165 } 166 167 public override IEnumerable<double> EstimatedTestClassValues { 168 get { 169 var rows = ProblemData.TestIndizes; 170 var estimatedValuesEnumerators = (from model in Model.Models 171 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 172 .ToList(); 173 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 174 // aggregate to make sure that MoveNext is called for all enumerators 175 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 176 int currentRow = rowsEnumerator.Current; 177 178 var selectedEnumerators = from pair in estimatedValuesEnumerators 179 where RowIsTestForModel(currentRow, pair.Model) 180 select pair.EstimatedValuesEnumerator; 181 182 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 183 } 184 } 185 } 186 187 private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) { 188 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 189 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 190 } 191 192 private bool RowIsTestForModel(int currentRow, IClassificationModel model) { 193 return testPartitions == null || !testPartitions.ContainsKey(model) || 194 (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End); 195 } 196 197 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 198 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 199 select AggregateEstimatedClassValues(xs); 200 } 60 201 61 202 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { 62 var estimatedValuesEnumerators = (from model in models203 var estimatedValuesEnumerators = (from model in Model.Models 63 204 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator()) 64 205 .ToList(); … … 70 211 } 71 212 213 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 214 return estimatedClassValues 215 .GroupBy(x => x) 216 .OrderBy(g => -g.Count()) 217 .Select(g => g.Key) 218 .DefaultIfEmpty(double.NaN) 219 .First(); 220 } 72 221 #endregion 73 222 74 #region IClassificationModel Members 75 76 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 77 foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) { 78 // return the class which is most often occuring 79 yield return 80 estimatedValuesVector 81 .GroupBy(x => x) 82 .OrderBy(g => -g.Count()) 83 .Select(g => g.Key) 84 .First(); 85 } 86 } 87 88 #endregion 223 protected override void OnProblemDataChanged() { 224 IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset, 225 ProblemData.AllowedInputVariables, 226 ProblemData.TargetVariable); 227 problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start; 228 problemData.TrainingPartition.End = ProblemData.TrainingPartition.End; 229 problemData.TestPartition.Start = ProblemData.TestPartition.Start; 230 problemData.TestPartition.End = ProblemData.TestPartition.End; 231 232 foreach (var solution in ClassificationSolutions) { 233 if (solution is ClassificationEnsembleSolution) 234 solution.ProblemData = ProblemData; 235 else 236 solution.ProblemData = problemData; 237 } 238 foreach (var trainingPartition in trainingPartitions.Values) { 239 trainingPartition.Start = ProblemData.TrainingPartition.Start; 240 trainingPartition.End = ProblemData.TrainingPartition.End; 241 } 242 foreach (var testPartition in testPartitions.Values) { 243 testPartition.Start = ProblemData.TestPartition.Start; 244 testPartition.End = ProblemData.TestPartition.End; 245 } 246 247 base.OnProblemDataChanged(); 248 } 249 250 public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 251 classificationSolutions.AddRange(solutions); 252 } 253 public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 254 classificationSolutions.RemoveRange(solutions); 255 } 256 257 private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 258 foreach (var solution in e.Items) AddClassificationSolution(solution); 259 RecalculateResults(); 260 } 261 private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 262 foreach (var solution in e.Items) RemoveClassificationSolution(solution); 263 RecalculateResults(); 264 } 265 private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 266 foreach (var solution in e.OldItems) RemoveClassificationSolution(solution); 267 foreach (var solution in e.Items) AddClassificationSolution(solution); 268 RecalculateResults(); 269 } 270 271 private void AddClassificationSolution(IClassificationSolution solution) { 272 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 273 Model.Add(solution.Model); 274 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 275 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 276 } 277 278 private void RemoveClassificationSolution(IClassificationSolution solution) { 279 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 280 Model.Remove(solution.Model); 281 trainingPartitions.Remove(solution.Model); 282 testPartitions.Remove(solution.Model); 283 } 89 284 } 90 285 }
Note: See TracChangeset
for help on using the changeset viewer.