Changeset 6239 for trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp: 05/20/11 16:10:07 (14 years ago)
- Location: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Files:
  - 1 added
  - 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs
r5809 r6239 39 39 get { return new List<IClassificationModel>(models); } 40 40 } 41 41 42 [StorableConstructor] 42 43 protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { } -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r6184 r6239 25 25 using HeuristicLab.Core; 26 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using HeuristicLab.Data; 28 using System; 27 29 28 30 namespace HeuristicLab.Problems.DataAnalysis { … … 33 35 [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")] 34 36 // [Creatable("Data Analysis")] 35 public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution { 37 public class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 38 39 public new IClassificationEnsembleModel Model { 40 set { base.Model = value; } 41 get { return (IClassificationEnsembleModel)base.Model; } 42 } 36 43 37 44 [Storable] 38 private List<IClassificationModel> models; 39 public IEnumerable<IClassificationModel> Models { 40 get { return new List<IClassificationModel>(models); } 41 } 45 private Dictionary<IClassificationModel, IntRange> trainingPartitions; 46 [Storable] 47 private Dictionary<IClassificationModel, IntRange> testPartitions; 48 49 42 50 [StorableConstructor] 43 51 protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { } 44 52 protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 45 53 : base(original, cloner) { 46 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 54 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 55 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 56 foreach (var model in Model.Models) { 57 trainingPartitions[model] = (IntRange)ProblemData.TrainingPartition.Clone(); 58 testPartitions[model] = (IntRange)ProblemData.TestPartition.Clone(); 59 } 60 RecalculateResults(); 47 61 } 48 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models )49 : base( ) {62 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, 
IClassificationProblemData problemData) 63 : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) { 50 64 this.name = ItemName; 51 65 this.description = ItemDescription; 52 this.models = new List<IClassificationModel>(models); 66 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 67 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 68 foreach (var model in models) { 69 trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone(); 70 testPartitions[model] = (IntRange)problemData.TestPartition.Clone(); 71 } 72 RecalculateResults(); 73 } 74 75 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 76 : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) { 77 this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 78 this.testPartitions = new Dictionary<IClassificationModel, IntRange>(); 79 var modelEnumerator = models.GetEnumerator(); 80 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 81 var testPartitionEnumerator = testPartitions.GetEnumerator(); 82 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 83 this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone(); 84 this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone(); 85 } 86 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 87 throw new ArgumentException(); 88 } 89 RecalculateResults(); 53 90 } 54 91 … … 57 94 } 58 95 59 #region IClassificationEnsembleModel Members 96 public override IEnumerable<double> EstimatedTrainingClassValues { 97 get { 98 var rows = 
ProblemData.TrainingIndizes; 99 var estimatedValuesEnumerators = (from model in Model.Models 100 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 101 .ToList(); 102 var rowsEnumerator = rows.GetEnumerator(); 103 // aggregate to make sure that MoveNext is called for all enumerators 104 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 105 int currentRow = rowsEnumerator.Current; 106 107 var selectedEnumerators = from pair in estimatedValuesEnumerators 108 where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) || 109 (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End) 110 select pair.EstimatedValuesEnumerator; 111 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 112 } 113 } 114 } 115 116 public override IEnumerable<double> EstimatedTestClassValues { 117 get { 118 var rows = ProblemData.TestIndizes; 119 var estimatedValuesEnumerators = (from model in Model.Models 120 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 121 .ToList(); 122 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 123 // aggregate to make sure that MoveNext is called for all enumerators 124 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 125 int currentRow = rowsEnumerator.Current; 126 127 var selectedEnumerators = from pair in estimatedValuesEnumerators 128 where testPartitions == null || !testPartitions.ContainsKey(pair.Model) || 129 (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End) 130 select pair.EstimatedValuesEnumerator; 131 132 yield return 
AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 133 } 134 } 135 } 136 137 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 138 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 139 select AggregateEstimatedClassValues(xs); 140 } 60 141 61 142 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { 62 var estimatedValuesEnumerators = (from model in models143 var estimatedValuesEnumerators = (from model in Model.Models 63 144 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator()) 64 145 .ToList(); … … 70 151 } 71 152 72 #endregion 73 74 #region IClassificationModel Members 75 76 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 77 foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) { 78 // return the class which is most often occuring 79 yield return 80 estimatedValuesVector 81 .GroupBy(x => x) 82 .OrderBy(g => -g.Count()) 83 .Select(g => g.Key) 84 .First(); 85 } 153 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 154 return estimatedClassValues 155 .GroupBy(x => x) 156 .OrderBy(g => -g.Count()) 157 .Select(g => g.Key) 158 .First(); 86 159 } 87 88 #endregion89 160 } 90 161 }
Note: See TracChangeset for help on using the changeset viewer.