Changeset 6760 for branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 09/14/11 13:59:25 (13 years ago)
- Location:
- branches/PersistenceSpeedUp
- Files:
-
- 8 edited
- 3 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/PersistenceSpeedUp
- Property svn:ignore
-
old new 12 12 *.psess 13 13 *.vsp 14 *.docstates
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs
r5809 r6760 39 39 get { return new List<IClassificationModel>(models); } 40 40 } 41 41 42 [StorableConstructor] 42 43 protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { } … … 45 46 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 46 47 } 48 49 public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { } 47 50 public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models) 48 51 : base() { … … 57 60 58 61 #region IClassificationEnsembleModel Members 62 public void Add(IClassificationModel model) { 63 models.Add(model); 64 } 65 public void Remove(IClassificationModel model) { 66 models.Remove(model); 67 } 59 68 60 69 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 85 94 } 86 95 96 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 97 return new ClassificationEnsembleSolution(models, problemData); 98 } 87 99 #endregion 88 100 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 30 … … 32 35 [StorableClass] 33 36 [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")] 34 // [Creatable("Data Analysis")] 35 public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution { 37 [Creatable("Data Analysis - Ensembles")] 38 public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 39 public new IClassificationEnsembleModel Model { 40 get { return (IClassificationEnsembleModel)base.Model; } 41 } 42 public new ClassificationEnsembleProblemData ProblemData { 43 get { return (ClassificationEnsembleProblemData)base.ProblemData; } 44 set { base.ProblemData = value; } 45 } 46 47 private readonly ItemCollection<IClassificationSolution> classificationSolutions; 48 public IItemCollection<IClassificationSolution> ClassificationSolutions { 49 get { return classificationSolutions; } 50 } 36 51 37 52 [Storable] 38 private List<IClassificationModel> models;39 public IEnumerable<IClassificationModel> Models {40 get { return new List<IClassificationModel>(models); }41 } 53 private Dictionary<IClassificationModel, IntRange> trainingPartitions; 54 [Storable] 55 private Dictionary<IClassificationModel, IntRange> testPartitions; 56 42 57 [StorableConstructor] 43 protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { } 44 protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 58 private ClassificationEnsembleSolution(bool deserializing) 59 : base(deserializing) { 60 classificationSolutions = new 
ItemCollection<IClassificationSolution>(); 61 } 62 [StorableHook(HookType.AfterDeserialization)] 63 private void AfterDeserialization() { 64 foreach (var model in Model.Models) { 65 IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone(); 66 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 67 problemData.TrainingPartition.End = trainingPartitions[model].End; 68 problemData.TestPartition.Start = testPartitions[model].Start; 69 problemData.TestPartition.End = testPartitions[model].End; 70 71 classificationSolutions.Add(model.CreateClassificationSolution(problemData)); 72 } 73 RegisterClassificationSolutionsEventHandler(); 74 } 75 76 private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 45 77 : base(original, cloner) { 46 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 47 } 48 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models) 49 : base() { 50 this.name = ItemName; 51 this.description = ItemDescription; 52 this.models = new List<IClassificationModel>(models); 78 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 79 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 80 foreach (var pair in original.trainingPartitions) { 81 trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 82 } 83 foreach (var pair in original.testPartitions) { 84 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 85 } 86 87 classificationSolutions = cloner.Clone(original.classificationSolutions); 88 RegisterClassificationSolutionsEventHandler(); 89 } 90 91 public ClassificationEnsembleSolution() 92 : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) { 93 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 94 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 95 classificationSolutions = new 
ItemCollection<IClassificationSolution>(); 96 97 RegisterClassificationSolutionsEventHandler(); 98 } 99 100 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData) 101 : this(models, problemData, 102 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 103 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 104 ) { } 105 106 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 107 : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) { 108 this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 109 this.testPartitions = new Dictionary<IClassificationModel, IntRange>(); 110 this.classificationSolutions = new ItemCollection<IClassificationSolution>(); 111 112 List<IClassificationSolution> solutions = new List<IClassificationSolution>(); 113 var modelEnumerator = models.GetEnumerator(); 114 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 115 var testPartitionEnumerator = testPartitions.GetEnumerator(); 116 117 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 118 var p = (IClassificationProblemData)problemData.Clone(); 119 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 120 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 121 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 122 p.TestPartition.End = testPartitionEnumerator.Current.End; 123 124 solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p)); 125 } 126 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 127 throw new ArgumentException(); 128 } 129 130 
RegisterClassificationSolutionsEventHandler(); 131 classificationSolutions.AddRange(solutions); 53 132 } 54 133 … … 56 135 return new ClassificationEnsembleSolution(this, cloner); 57 136 } 58 59 #region IClassificationEnsembleModel Members 137 private void RegisterClassificationSolutionsEventHandler() { 138 classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded); 139 classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved); 140 classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset); 141 } 142 143 protected override void RecalculateResults() { 144 CalculateResults(); 145 } 146 147 #region Evaluation 148 public override IEnumerable<double> EstimatedTrainingClassValues { 149 get { 150 var rows = ProblemData.TrainingIndizes; 151 var estimatedValuesEnumerators = (from model in Model.Models 152 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 153 .ToList(); 154 var rowsEnumerator = rows.GetEnumerator(); 155 // aggregate to make sure that MoveNext is called for all enumerators 156 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 157 int currentRow = rowsEnumerator.Current; 158 159 var selectedEnumerators = from pair in estimatedValuesEnumerators 160 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 161 select pair.EstimatedValuesEnumerator; 162 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 163 } 164 } 165 } 166 167 public override IEnumerable<double> EstimatedTestClassValues { 168 get { 169 var rows = ProblemData.TestIndizes; 170 var 
estimatedValuesEnumerators = (from model in Model.Models 171 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 172 .ToList(); 173 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 174 // aggregate to make sure that MoveNext is called for all enumerators 175 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 176 int currentRow = rowsEnumerator.Current; 177 178 var selectedEnumerators = from pair in estimatedValuesEnumerators 179 where RowIsTestForModel(currentRow, pair.Model) 180 select pair.EstimatedValuesEnumerator; 181 182 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 183 } 184 } 185 } 186 187 private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) { 188 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 189 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 190 } 191 192 private bool RowIsTestForModel(int currentRow, IClassificationModel model) { 193 return testPartitions == null || !testPartitions.ContainsKey(model) || 194 (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End); 195 } 196 197 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 198 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 199 select AggregateEstimatedClassValues(xs); 200 } 60 201 61 202 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { 62 var estimatedValuesEnumerators = (from model in models203 var estimatedValuesEnumerators = (from model in Model.Models 63 204 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator()) 64 205 .ToList(); … … 70 211 } 71 212 213 private double 
AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 214 return estimatedClassValues 215 .GroupBy(x => x) 216 .OrderBy(g => -g.Count()) 217 .Select(g => g.Key) 218 .DefaultIfEmpty(double.NaN) 219 .First(); 220 } 72 221 #endregion 73 222 74 #region IClassificationModel Members 75 76 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 77 foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) { 78 // return the class which is most often occuring 79 yield return 80 estimatedValuesVector 81 .GroupBy(x => x) 82 .OrderBy(g => -g.Count()) 83 .Select(g => g.Key) 84 .First(); 85 } 86 } 87 88 #endregion 223 protected override void OnProblemDataChanged() { 224 IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset, 225 ProblemData.AllowedInputVariables, 226 ProblemData.TargetVariable); 227 problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start; 228 problemData.TrainingPartition.End = ProblemData.TrainingPartition.End; 229 problemData.TestPartition.Start = ProblemData.TestPartition.Start; 230 problemData.TestPartition.End = ProblemData.TestPartition.End; 231 232 foreach (var solution in ClassificationSolutions) { 233 if (solution is ClassificationEnsembleSolution) 234 solution.ProblemData = ProblemData; 235 else 236 solution.ProblemData = problemData; 237 } 238 foreach (var trainingPartition in trainingPartitions.Values) { 239 trainingPartition.Start = ProblemData.TrainingPartition.Start; 240 trainingPartition.End = ProblemData.TrainingPartition.End; 241 } 242 foreach (var testPartition in testPartitions.Values) { 243 testPartition.Start = ProblemData.TestPartition.Start; 244 testPartition.End = ProblemData.TestPartition.End; 245 } 246 247 base.OnProblemDataChanged(); 248 } 249 250 public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 251 classificationSolutions.AddRange(solutions); 252 } 253 public 
void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 254 classificationSolutions.RemoveRange(solutions); 255 } 256 257 private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 258 foreach (var solution in e.Items) AddClassificationSolution(solution); 259 RecalculateResults(); 260 } 261 private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 262 foreach (var solution in e.Items) RemoveClassificationSolution(solution); 263 RecalculateResults(); 264 } 265 private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 266 foreach (var solution in e.OldItems) RemoveClassificationSolution(solution); 267 foreach (var solution in e.Items) AddClassificationSolution(solution); 268 RecalculateResults(); 269 } 270 271 private void AddClassificationSolution(IClassificationSolution solution) { 272 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 273 Model.Add(solution.Model); 274 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 275 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 276 } 277 278 private void RemoveClassificationSolution(IClassificationSolution solution) { 279 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 280 Model.Remove(solution.Model); 281 trainingPartitions.Remove(solution.Model); 282 testPartitions.Remove(solution.Model); 283 } 89 284 } 90 285 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs
r6232 r6760 34 34 [Item("ClassificationProblemData", "Represents an item containing all data defining a classification problem.")] 35 35 public class ClassificationProblemData : DataAnalysisProblemData, IClassificationProblemData { 36 private const string TargetVariableParameterName = "TargetVariable"; 37 private const string ClassNamesParameterName = "ClassNames"; 38 private const string ClassificationPenaltiesParameterName = "ClassificationPenalties"; 39 private const int MaximumNumberOfClasses = 20; 40 private const int InspectedRowsToDetermineTargets = 500; 36 protected const string TargetVariableParameterName = "TargetVariable"; 37 protected const string ClassNamesParameterName = "ClassNames"; 38 protected const string ClassificationPenaltiesParameterName = "ClassificationPenalties"; 39 protected const int MaximumNumberOfClasses = 20; 40 protected const int InspectedRowsToDetermineTargets = 500; 41 41 42 42 #region default data … … 171 171 {1176881,7,5,3,7,4,10,7,5,5,4 } 172 172 }; 173 private static Dataset defaultDataset; 174 private static IEnumerable<string> defaultAllowedInputVariables; 175 private static string defaultTargetVariable; 173 private static readonly Dataset defaultDataset; 174 private static readonly IEnumerable<string> defaultAllowedInputVariables; 175 private static readonly string defaultTargetVariable; 176 177 private static readonly ClassificationProblemData emptyProblemData; 178 public static ClassificationProblemData EmptyProblemData { 179 get { return EmptyProblemData; } 180 } 181 176 182 static ClassificationProblemData() { 177 183 defaultDataset = new Dataset(defaultVariableNames, defaultData); … … 181 187 defaultAllowedInputVariables = defaultVariableNames.Except(new List<string>() { "sample", "class" }); 182 188 defaultTargetVariable = "class"; 189 190 var problemData = new ClassificationProblemData(); 191 problemData.Parameters.Clear(); 192 problemData.Name = "Empty Classification ProblemData"; 193 problemData.Description = "This 
ProblemData acts as place holder before the correct problem data is loaded."; 194 problemData.isEmpty = true; 195 196 problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset())); 197 problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, "")); 198 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 199 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 200 problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>())); 201 problemData.Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, "", new StringMatrix(0, 0).AsReadOnly())); 202 problemData.Parameters.Add(new FixedValueParameter<DoubleMatrix>(ClassificationPenaltiesParameterName, "", (DoubleMatrix)new DoubleMatrix(0, 0).AsReadOnly())); 203 emptyProblemData = problemData; 183 204 } 184 205 #endregion 185 206 186 207 #region parameter properties 187 public IValueParameter<StringValue> TargetVariableParameter { 188 get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 208 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 209 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 189 210 } 190 211 public IFixedValueParameter<StringMatrix> ClassNamesParameter { … … 205 226 get { 206 227 if (classValues == null) { 207 classValues = Dataset.GetEnumeratedVariableValues(TargetVariableParameter.Value.Value).Distinct().ToList(); 228 classValues = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().ToList(); 208 229 classValues.Sort(); 209 230 } … … 249 270 RegisterParameterEvents(); 250 271 } 251 public override IDeepCloneable Clone(Cloner 
cloner) { return new ClassificationProblemData(this, cloner); } 272 public override IDeepCloneable Clone(Cloner cloner) { 273 if (this == emptyProblemData) return emptyProblemData; 274 return new ClassificationProblemData(this, cloner); 275 } 252 276 253 277 public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { } … … 267 291 private static IEnumerable<string> CheckVariablesForPossibleTargetVariables(Dataset dataset) { 268 292 int maxSamples = Math.Min(InspectedRowsToDetermineTargets, dataset.Rows); 269 var validTargetVariables = from v in dataset.VariableNames 270 let DistinctValues = dataset.GetVariableValues(v) 271 .Take(maxSamples) 272 .Distinct() 273 .Count() 274 where DistinctValues < MaximumNumberOfClasses 275 select v; 293 var validTargetVariables = (from v in dataset.DoubleVariables 294 let distinctValues = dataset.GetDoubleValues(v) 295 .Take(maxSamples) 296 .Distinct() 297 .Count() 298 where distinctValues < MaximumNumberOfClasses 299 select v).ToArray(); 276 300 277 301 if (!validTargetVariables.Any()) … … 283 307 284 308 private void ResetTargetVariableDependentMembers() { 285 DergisterParameterEvents(); 309 DeregisterParameterEvents(); 286 310 287 311 classNames = null; … … 357 381 ClassificationPenaltiesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged); 358 382 } 359 private void DergisterParameterEvents() { 383 private void DeregisterParameterEvents() { 360 384 TargetVariableParameter.ValueChanged -= new EventHandler(TargetVariableParameter_ValueChanged); 361 385 ClassNamesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged); … … 386 410 dataset.Name = Path.GetFileName(fileName); 387 411 388 ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset. 
VariableNames.Skip(1), dataset.VariableNames.First()); 412 ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First()); 389 413 problemData.Name = "Data imported from " + Path.GetFileName(fileName); 390 414 return problemData; -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 using HeuristicLab.Data;27 using HeuristicLab.Optimization;28 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 26 … … 33 30 /// </summary> 34 31 [StorableClass] 35 public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution { 36 private const string TrainingAccuracyResultName = "Accuracy (training)"; 37 private const string TestAccuracyResultName = "Accuracy (test)"; 38 39 public new IClassificationModel Model { 40 get { return (IClassificationModel)base.Model; } 41 protected set { base.Model = value; } 42 } 43 44 public new IClassificationProblemData ProblemData { 45 get { return (IClassificationProblemData)base.ProblemData; } 46 protected set { base.ProblemData = value; } 47 } 48 49 public double TrainingAccuracy { 50 get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; } 51 private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; } 52 } 53 54 public double TestAccuracy { 55 get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; } 56 private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; } 57 } 32 public abstract class ClassificationSolution : ClassificationSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 58 34 59 35 [StorableConstructor] 60 protected ClassificationSolution(bool deserializing) : base(deserializing) { } 36 protected ClassificationSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 61 40 protected ClassificationSolution(ClassificationSolution original, Cloner cloner) 62 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 63 43 } 64 44 public ClassificationSolution(IClassificationModel model, IClassificationProblemData 
problemData) 65 45 : base(model, problemData) { 66 Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue())); 67 Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue())); 68 RecalculateResults(); 46 evaluationCache = new Dictionary<int, double>(); 69 47 } 70 48 71 public override IDeepCloneable Clone(Cloner cloner) { 72 return new ClassificationSolution(this, cloner); 49 public override IEnumerable<double> EstimatedClassValues { 50 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 51 } 52 public override IEnumerable<double> EstimatedTrainingClassValues { 53 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 54 } 55 public override IEnumerable<double> EstimatedTestClassValues { 56 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 73 57 } 74 58 75 protected override void OnProblemDataChanged(EventArgs e) { 76 base.OnProblemDataChanged(e); 77 RecalculateResults(); 59 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 60 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 61 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 62 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 63 64 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 65 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 66 } 67 68 return rows.Select(row => evaluationCache[row]); 78 69 } 79 70 80 protected override void On ModelChanged(EventArgs e) {81 base.OnModelChanged(e);82 RecalculateResults();71 protected override void OnProblemDataChanged() { 72 evaluationCache.Clear(); 73 base.OnProblemDataChanged(); 83 74 } 84 75 85 protected void RecalculateResults() { 86 double[] estimatedTrainingClassValues = 
EstimatedTrainingClassValues.ToArray(); // cache values 87 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 88 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 89 IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 90 91 OnlineCalculatorError errorState; 92 double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState); 93 if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN; 94 double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState); 95 if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN; 96 97 TrainingAccuracy = trainingAccuracy; 98 TestAccuracy = testAccuracy; 99 } 100 101 public virtual IEnumerable<double> EstimatedClassValues { 102 get { 103 return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 104 } 105 } 106 107 public virtual IEnumerable<double> EstimatedTrainingClassValues { 108 get { 109 return GetEstimatedClassValues(ProblemData.TrainingIndizes); 110 } 111 } 112 113 public virtual IEnumerable<double> EstimatedTestClassValues { 114 get { 115 return GetEstimatedClassValues(ProblemData.TestIndizes); 116 } 117 } 118 119 public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 120 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 76 protected override void OnModelChanged() { 77 evaluationCache.Clear(); 78 base.OnModelChanged(); 121 79 } 122 80 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r5809 r6760 33 33 [StorableClass] 34 34 [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")] 35 public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel { 35 public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel { 36 36 [Storable] 37 37 private IRegressionModel model; … … 70 70 } 71 71 72 public override IDeepCloneable Clone(Cloner cloner) { 73 return new DiscriminantFunctionClassificationModel(this, cloner); 74 } 75 76 72 public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) { 77 73 var classValuesArr = classValues.ToArray(); … … 106 102 } 107 103 #endregion 104 105 public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData); 106 public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData); 108 107 } 109 108 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r5942 r6760 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; … … 26 25 using HeuristicLab.Core; 27 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 using HeuristicLab.Data;29 using HeuristicLab.Optimization;30 27 31 28 namespace HeuristicLab.Problems.DataAnalysis { … … 35 32 [StorableClass] 36 33 [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")] 37 public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution { 38 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 39 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 40 private const string TrainingRSquaredResultName = "Pearson's R² (training)"; 41 private const string TestRSquaredResultName = "Pearson's R² (test)"; 34 public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase { 35 protected readonly Dictionary<int, double> valueEvaluationCache; 36 protected readonly Dictionary<int, double> classValueEvaluationCache; 42 37 43 public new IDiscriminantFunctionClassificationModel Model { 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 45 protected set { 46 if (value != null && value != Model) { 47 if (Model != null) { 48 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 49 } 50 value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 51 base.Model = value; 52 } 53 } 38 [StorableConstructor] 39 protected DiscriminantFunctionClassificationSolution(bool deserializing) 40 : base(deserializing) { 41 valueEvaluationCache = new Dictionary<int, double>(); 42 classValueEvaluationCache = new Dictionary<int, double>(); 43 } 44 protected 
DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 45 : base(original, cloner) { 46 valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache); 47 classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache); 48 } 49 protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 50 : base(model, problemData) { 51 valueEvaluationCache = new Dictionary<int, double>(); 52 classValueEvaluationCache = new Dictionary<int, double>(); 53 54 SetAccuracyMaximizingThresholds(); 54 55 } 55 56 56 public double TrainingMeanSquaredError { 57 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 58 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 57 public override IEnumerable<double> EstimatedClassValues { 58 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 59 } 60 public override IEnumerable<double> EstimatedTrainingClassValues { 61 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 62 } 63 public override IEnumerable<double> EstimatedTestClassValues { 64 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 59 65 } 60 66 61 public double TestMeanSquaredError { 62 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 63 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 67 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 68 var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys); 69 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 70 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 71 72 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 73 
classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 74 } 75 76 return rows.Select(row => classValueEvaluationCache[row]); 64 77 } 65 78 66 public double TrainingRSquared {67 get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }68 private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }69 }70 79 71 public double TestRSquared { 72 get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; } 73 private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; } 74 } 75 76 [StorableConstructor] 77 protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { } 78 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 79 : base(original, cloner) { 80 RegisterEventHandler(); 81 } 82 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 83 : this(new DiscriminantFunctionClassificationModel(model), problemData) { 84 } 85 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 86 : base(model, problemData) { 87 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 88 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 89 Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 90 Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 91 RegisterEventHandler(); 92 SetAccuracyMaximizingThresholds(); 93 RecalculateResults(); 94 } 95 96 
[StorableHook(HookType.AfterDeserialization)] 97 private void AfterDeserialization() { 98 RegisterEventHandler(); 99 } 100 101 protected new void RecalculateResults() { 102 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 103 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 104 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 105 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 106 107 OnlineCalculatorError errorState; 108 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 109 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 110 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 111 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 112 113 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 114 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 115 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 116 TestRSquared = errorState == OnlineCalculatorError.None ? 
testR2 : double.NaN; 117 } 118 119 private void RegisterEventHandler() { 120 Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 121 } 122 private void Model_ThresholdsChanged(object sender, EventArgs e) { 123 OnModelThresholdsChanged(e); 124 } 125 126 public void SetAccuracyMaximizingThresholds() { 127 double[] classValues; 128 double[] thresholds; 129 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 130 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 131 132 Model.SetThresholdsAndClassValues(thresholds, classValues); 133 } 134 135 public void SetClassDistibutionCutPointThresholds() { 136 double[] classValues; 137 double[] thresholds; 138 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 139 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 140 141 Model.SetThresholdsAndClassValues(thresholds, classValues); 142 } 143 144 protected override void OnModelChanged(EventArgs e) { 145 base.OnModelChanged(e); 146 SetAccuracyMaximizingThresholds(); 147 RecalculateResults(); 148 } 149 150 protected override void OnProblemDataChanged(EventArgs e) { 151 base.OnProblemDataChanged(e); 152 SetAccuracyMaximizingThresholds(); 153 RecalculateResults(); 154 } 155 protected virtual void OnModelThresholdsChanged(EventArgs e) { 156 base.OnModelChanged(e); 157 RecalculateResults(); 158 } 159 160 public IEnumerable<double> EstimatedValues { 80 public override IEnumerable<double> EstimatedValues { 161 81 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 162 82 } 163 164 public IEnumerable<double> EstimatedTrainingValues { 83 public override IEnumerable<double> 
EstimatedTrainingValues { 165 84 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 166 85 } 167 168 public IEnumerable<double> EstimatedTestValues { 86 public override IEnumerable<double> EstimatedTestValues { 169 87 get { return GetEstimatedValues(ProblemData.TestIndizes); } 170 88 } 171 89 172 public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 173 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 90 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 91 var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys); 92 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 93 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 94 95 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 96 valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 97 } 98 99 return rows.Select(row => valueEvaluationCache[row]); 100 } 101 102 protected override void OnModelChanged() { 103 valueEvaluationCache.Clear(); 104 classValueEvaluationCache.Clear(); 105 base.OnModelChanged(); 106 } 107 protected override void OnModelThresholdsChanged(System.EventArgs e) { 108 classValueEvaluationCache.Clear(); 109 base.OnModelThresholdsChanged(e); 110 } 111 protected override void OnProblemDataChanged() { 112 valueEvaluationCache.Clear(); 113 classValueEvaluationCache.Clear(); 114 base.OnProblemDataChanged(); 174 115 } 175 116 }
Note: See TracChangeset for help on using the changeset viewer.