Changeset 6760 for branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation
- Timestamp:
- 09/14/11 13:59:25 (13 years ago)
- Location:
- branches/PersistenceSpeedUp
- Files:
-
- 17 edited
- 5 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/PersistenceSpeedUp
- Property svn:ignore
-
old new 12 12 *.psess 13 13 *.vsp 14 *.docstates
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs
r5809 r6760 39 39 get { return new List<IClassificationModel>(models); } 40 40 } 41 41 42 [StorableConstructor] 42 43 protected ClassificationEnsembleModel(bool deserializing) : base(deserializing) { } … … 45 46 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 46 47 } 48 49 public ClassificationEnsembleModel() : this(Enumerable.Empty<IClassificationModel>()) { } 47 50 public ClassificationEnsembleModel(IEnumerable<IClassificationModel> models) 48 51 : base() { … … 57 60 58 61 #region IClassificationEnsembleModel Members 62 public void Add(IClassificationModel model) { 63 models.Add(model); 64 } 65 public void Remove(IClassificationModel model) { 66 models.Remove(model); 67 } 59 68 60 69 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 85 94 } 86 95 96 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 97 return new ClassificationEnsembleSolution(models, problemData); 98 } 87 99 #endregion 88 100 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 30 … … 32 35 [StorableClass] 33 36 [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")] 34 // [Creatable("Data Analysis")] 35 public class ClassificationEnsembleSolution : NamedItem, IClassificationEnsembleSolution { 37 [Creatable("Data Analysis - Ensembles")] 38 public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 39 public new IClassificationEnsembleModel Model { 40 get { return (IClassificationEnsembleModel)base.Model; } 41 } 42 public new ClassificationEnsembleProblemData ProblemData { 43 get { return (ClassificationEnsembleProblemData)base.ProblemData; } 44 set { base.ProblemData = value; } 45 } 46 47 private readonly ItemCollection<IClassificationSolution> classificationSolutions; 48 public IItemCollection<IClassificationSolution> ClassificationSolutions { 49 get { return classificationSolutions; } 50 } 36 51 37 52 [Storable] 38 private List<IClassificationModel> models;39 public IEnumerable<IClassificationModel> Models {40 get { return new List<IClassificationModel>(models); }41 } 53 private Dictionary<IClassificationModel, IntRange> trainingPartitions; 54 [Storable] 55 private Dictionary<IClassificationModel, IntRange> testPartitions; 56 42 57 [StorableConstructor] 43 protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { } 44 protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 58 private ClassificationEnsembleSolution(bool deserializing) 59 : base(deserializing) { 60 classificationSolutions = new ItemCollection<IClassificationSolution>(); 61 } 62 [StorableHook(HookType.AfterDeserialization)] 63 private void AfterDeserialization() { 64 foreach (var model in Model.Models) { 65 IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone(); 66 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 67 problemData.TrainingPartition.End = trainingPartitions[model].End; 68 problemData.TestPartition.Start = testPartitions[model].Start; 69 problemData.TestPartition.End = testPartitions[model].End; 70 71 classificationSolutions.Add(model.CreateClassificationSolution(problemData)); 72 } 73 RegisterClassificationSolutionsEventHandler(); 74 } 75 76 private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 45 77 : base(original, cloner) { 46 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 47 } 48 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models) 49 : base() { 50 this.name = ItemName; 51 this.description = ItemDescription; 52 this.models = new List<IClassificationModel>(models); 78 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 79 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 80 foreach (var pair in original.trainingPartitions) { 81 trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 82 } 83 foreach (var pair in original.testPartitions) { 84 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 85 } 86 87 classificationSolutions = cloner.Clone(original.classificationSolutions); 88 RegisterClassificationSolutionsEventHandler(); 89 } 90 91 public ClassificationEnsembleSolution() 92 : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) { 93 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 94 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 95 classificationSolutions = new ItemCollection<IClassificationSolution>(); 96 97 RegisterClassificationSolutionsEventHandler(); 98 } 99 100 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData) 101 : this(models, problemData, 102 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 103 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 104 ) { } 105 106 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 107 : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) { 108 this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 109 this.testPartitions = new Dictionary<IClassificationModel, IntRange>(); 110 this.classificationSolutions = new ItemCollection<IClassificationSolution>(); 111 112 List<IClassificationSolution> solutions = new List<IClassificationSolution>(); 113 var modelEnumerator = models.GetEnumerator(); 114 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 115 var testPartitionEnumerator = testPartitions.GetEnumerator(); 116 117 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 118 var p = (IClassificationProblemData)problemData.Clone(); 119 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 120 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 121 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 122 p.TestPartition.End = testPartitionEnumerator.Current.End; 123 124 solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p)); 125 } 126 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 127 throw new ArgumentException(); 128 } 129 130 RegisterClassificationSolutionsEventHandler(); 131 classificationSolutions.AddRange(solutions); 53 132 } 54 133 … … 56 135 return new ClassificationEnsembleSolution(this, cloner); 57 136 } 58 59 #region IClassificationEnsembleModel Members 137 private void RegisterClassificationSolutionsEventHandler() { 138 classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded); 139 classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved); 140 classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset); 141 } 142 143 protected override void RecalculateResults() { 144 CalculateResults(); 145 } 146 147 #region Evaluation 148 public override IEnumerable<double> EstimatedTrainingClassValues { 149 get { 150 var rows = ProblemData.TrainingIndizes; 151 var estimatedValuesEnumerators = (from model in Model.Models 152 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 153 .ToList(); 154 var rowsEnumerator = rows.GetEnumerator(); 155 // aggregate to make sure that MoveNext is called for all enumerators 156 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 157 int currentRow = rowsEnumerator.Current; 158 159 var selectedEnumerators = from pair in estimatedValuesEnumerators 160 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 161 select pair.EstimatedValuesEnumerator; 162 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 163 } 164 } 165 } 166 167 public override IEnumerable<double> EstimatedTestClassValues { 168 get { 169 var rows = ProblemData.TestIndizes; 170 var estimatedValuesEnumerators = (from model in Model.Models 171 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 172 .ToList(); 173 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 174 // aggregate to make sure that MoveNext is called for all enumerators 175 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 176 int currentRow = rowsEnumerator.Current; 177 178 var selectedEnumerators = from pair in estimatedValuesEnumerators 179 where RowIsTestForModel(currentRow, pair.Model) 180 select pair.EstimatedValuesEnumerator; 181 182 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 183 } 184 } 185 } 186 187 private bool RowIsTrainingForModel(int currentRow, IClassificationModel model) { 188 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 189 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 190 } 191 192 private bool RowIsTestForModel(int currentRow, IClassificationModel model) { 193 return testPartitions == null || !testPartitions.ContainsKey(model) || 194 (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End); 195 } 196 197 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 198 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 199 select AggregateEstimatedClassValues(xs); 200 } 60 201 61 202 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { 62 var estimatedValuesEnumerators = (from model in models203 var estimatedValuesEnumerators = (from model in Model.Models 63 204 select model.GetEstimatedClassValues(dataset, rows).GetEnumerator()) 64 205 .ToList(); … … 70 211 } 71 212 213 private double AggregateEstimatedClassValues(IEnumerable<double> estimatedClassValues) { 214 return estimatedClassValues 215 .GroupBy(x => x) 216 .OrderBy(g => -g.Count()) 217 .Select(g => g.Key) 218 .DefaultIfEmpty(double.NaN) 219 .First(); 220 } 72 221 #endregion 73 222 74 #region IClassificationModel Members 75 76 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 77 foreach (var estimatedValuesVector in GetEstimatedClassValueVectors(dataset, rows)) { 78 // return the class which is most often occuring 79 yield return 80 estimatedValuesVector 81 .GroupBy(x => x) 82 .OrderBy(g => -g.Count()) 83 .Select(g => g.Key) 84 .First(); 85 } 86 } 87 88 #endregion 223 protected override void OnProblemDataChanged() { 224 IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset, 225 ProblemData.AllowedInputVariables, 226 ProblemData.TargetVariable); 227 problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start; 228 problemData.TrainingPartition.End = ProblemData.TrainingPartition.End; 229 problemData.TestPartition.Start = ProblemData.TestPartition.Start; 230 problemData.TestPartition.End = ProblemData.TestPartition.End; 231 232 foreach (var solution in ClassificationSolutions) { 233 if (solution is ClassificationEnsembleSolution) 234 solution.ProblemData = ProblemData; 235 else 236 solution.ProblemData = problemData; 237 } 238 foreach (var trainingPartition in trainingPartitions.Values) { 239 trainingPartition.Start = ProblemData.TrainingPartition.Start; 240 trainingPartition.End = ProblemData.TrainingPartition.End; 241 } 242 foreach (var testPartition in testPartitions.Values) { 243 testPartition.Start = ProblemData.TestPartition.Start; 244 testPartition.End = ProblemData.TestPartition.End; 245 } 246 247 base.OnProblemDataChanged(); 248 } 249 250 public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 251 classificationSolutions.AddRange(solutions); 252 } 253 public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 254 classificationSolutions.RemoveRange(solutions); 255 } 256 257 private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 258 foreach (var solution in e.Items) AddClassificationSolution(solution); 259 RecalculateResults(); 260 } 261 private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 262 foreach (var solution in e.Items) RemoveClassificationSolution(solution); 263 RecalculateResults(); 264 } 265 private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 266 foreach (var solution in e.OldItems) RemoveClassificationSolution(solution); 267 foreach (var solution in e.Items) AddClassificationSolution(solution); 268 RecalculateResults(); 269 } 270 271 private void AddClassificationSolution(IClassificationSolution solution) { 272 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 273 Model.Add(solution.Model); 274 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 275 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 276 } 277 278 private void RemoveClassificationSolution(IClassificationSolution solution) { 279 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 280 Model.Remove(solution.Model); 281 trainingPartitions.Remove(solution.Model); 282 testPartitions.Remove(solution.Model); 283 } 89 284 } 90 285 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs
r6232 r6760 34 34 [Item("ClassificationProblemData", "Represents an item containing all data defining a classification problem.")] 35 35 public class ClassificationProblemData : DataAnalysisProblemData, IClassificationProblemData { 36 pr ivateconst string TargetVariableParameterName = "TargetVariable";37 pr ivateconst string ClassNamesParameterName = "ClassNames";38 pr ivateconst string ClassificationPenaltiesParameterName = "ClassificationPenalties";39 pr ivateconst int MaximumNumberOfClasses = 20;40 pr ivateconst int InspectedRowsToDetermineTargets = 500;36 protected const string TargetVariableParameterName = "TargetVariable"; 37 protected const string ClassNamesParameterName = "ClassNames"; 38 protected const string ClassificationPenaltiesParameterName = "ClassificationPenalties"; 39 protected const int MaximumNumberOfClasses = 20; 40 protected const int InspectedRowsToDetermineTargets = 500; 41 41 42 42 #region default data … … 171 171 {1176881,7,5,3,7,4,10,7,5,5,4 } 172 172 }; 173 private static Dataset defaultDataset; 174 private static IEnumerable<string> defaultAllowedInputVariables; 175 private static string defaultTargetVariable; 173 private static readonly Dataset defaultDataset; 174 private static readonly IEnumerable<string> defaultAllowedInputVariables; 175 private static readonly string defaultTargetVariable; 176 177 private static readonly ClassificationProblemData emptyProblemData; 178 public static ClassificationProblemData EmptyProblemData { 179 get { return EmptyProblemData; } 180 } 181 176 182 static ClassificationProblemData() { 177 183 defaultDataset = new Dataset(defaultVariableNames, defaultData); … … 181 187 defaultAllowedInputVariables = defaultVariableNames.Except(new List<string>() { "sample", "class" }); 182 188 defaultTargetVariable = "class"; 189 190 var problemData = new ClassificationProblemData(); 191 problemData.Parameters.Clear(); 192 problemData.Name = "Empty Classification ProblemData"; 193 problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded."; 194 problemData.isEmpty = true; 195 196 problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset())); 197 problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, "")); 198 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 199 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 200 problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>())); 201 problemData.Parameters.Add(new FixedValueParameter<StringMatrix>(ClassNamesParameterName, "", new StringMatrix(0, 0).AsReadOnly())); 202 problemData.Parameters.Add(new FixedValueParameter<DoubleMatrix>(ClassificationPenaltiesParameterName, "", (DoubleMatrix)new DoubleMatrix(0, 0).AsReadOnly())); 203 emptyProblemData = problemData; 183 204 } 184 205 #endregion 185 206 186 207 #region parameter properties 187 public IValueParameter<StringValue> TargetVariableParameter {188 get { return ( IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }208 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 209 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 189 210 } 190 211 public IFixedValueParameter<StringMatrix> ClassNamesParameter { … … 205 226 get { 206 227 if (classValues == null) { 207 classValues = Dataset.Get EnumeratedVariableValues(TargetVariableParameter.Value.Value).Distinct().ToList();228 classValues = Dataset.GetDoubleValues(TargetVariableParameter.Value.Value).Distinct().ToList(); 208 229 classValues.Sort(); 209 230 } … … 249 270 RegisterParameterEvents(); 250 271 } 251 public override IDeepCloneable Clone(Cloner cloner) { return new ClassificationProblemData(this, cloner); } 272 public override IDeepCloneable Clone(Cloner cloner) { 273 if (this == emptyProblemData) return emptyProblemData; 274 return new ClassificationProblemData(this, cloner); 275 } 252 276 253 277 public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { } … … 267 291 private static IEnumerable<string> CheckVariablesForPossibleTargetVariables(Dataset dataset) { 268 292 int maxSamples = Math.Min(InspectedRowsToDetermineTargets, dataset.Rows); 269 var validTargetVariables = from v in dataset.VariableNames270 let DistinctValues = dataset.GetVariableValues(v)271 .Take(maxSamples)272 .Distinct()273 .Count()274 where DistinctValues < MaximumNumberOfClasses275 select v;293 var validTargetVariables = (from v in dataset.DoubleVariables 294 let distinctValues = dataset.GetDoubleValues(v) 295 .Take(maxSamples) 296 .Distinct() 297 .Count() 298 where distinctValues < MaximumNumberOfClasses 299 select v).ToArray(); 276 300 277 301 if (!validTargetVariables.Any()) … … 283 307 284 308 private void ResetTargetVariableDependentMembers() { 285 Der gisterParameterEvents();309 DeregisterParameterEvents(); 286 310 287 311 classNames = null; … … 357 381 ClassificationPenaltiesParameter.Value.ItemChanged += new EventHandler<EventArgs<int, int>>(MatrixParameter_ItemChanged); 358 382 } 359 private void Der gisterParameterEvents() {383 private void DeregisterParameterEvents() { 360 384 TargetVariableParameter.ValueChanged -= new EventHandler(TargetVariableParameter_ValueChanged); 361 385 ClassNamesParameter.Value.Reset -= new EventHandler(Parameter_ValueChanged); … … 386 410 dataset.Name = Path.GetFileName(fileName); 387 411 388 ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset. VariableNames.Skip(1), dataset.VariableNames.First());412 ClassificationProblemData problemData = new ClassificationProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First()); 389 413 problemData.Name = "Data imported from " + Path.GetFileName(fileName); 390 414 return problemData; -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 using HeuristicLab.Data;27 using HeuristicLab.Optimization;28 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 26 … … 33 30 /// </summary> 34 31 [StorableClass] 35 public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution { 36 private const string TrainingAccuracyResultName = "Accuracy (training)"; 37 private const string TestAccuracyResultName = "Accuracy (test)"; 38 39 public new IClassificationModel Model { 40 get { return (IClassificationModel)base.Model; } 41 protected set { base.Model = value; } 42 } 43 44 public new IClassificationProblemData ProblemData { 45 get { return (IClassificationProblemData)base.ProblemData; } 46 protected set { base.ProblemData = value; } 47 } 48 49 public double TrainingAccuracy { 50 get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; } 51 private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; } 52 } 53 54 public double TestAccuracy { 55 get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; } 56 private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; } 57 } 32 public abstract class ClassificationSolution : ClassificationSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 58 34 59 35 [StorableConstructor] 60 protected ClassificationSolution(bool deserializing) : base(deserializing) { } 36 protected ClassificationSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 61 40 protected ClassificationSolution(ClassificationSolution original, Cloner cloner) 62 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 63 43 } 64 44 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData) 65 45 : base(model, problemData) { 66 Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue())); 67 Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue())); 68 RecalculateResults(); 46 evaluationCache = new Dictionary<int, double>(); 69 47 } 70 48 71 public override IDeepCloneable Clone(Cloner cloner) { 72 return new ClassificationSolution(this, cloner); 49 public override IEnumerable<double> EstimatedClassValues { 50 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 51 } 52 public override IEnumerable<double> EstimatedTrainingClassValues { 53 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 54 } 55 public override IEnumerable<double> EstimatedTestClassValues { 56 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 73 57 } 74 58 75 protected override void OnProblemDataChanged(EventArgs e) { 76 base.OnProblemDataChanged(e); 77 RecalculateResults(); 59 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 60 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 61 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 62 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 63 64 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 65 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 66 } 67 68 return rows.Select(row => evaluationCache[row]); 78 69 } 79 70 80 protected override void On ModelChanged(EventArgs e) {81 base.OnModelChanged(e);82 RecalculateResults();71 protected override void OnProblemDataChanged() { 72 evaluationCache.Clear(); 73 base.OnProblemDataChanged(); 83 74 } 84 75 85 protected void RecalculateResults() { 86 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 87 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 88 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 89 IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 90 91 OnlineCalculatorError errorState; 92 double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState); 93 if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN; 94 double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState); 95 if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN; 96 97 TrainingAccuracy = trainingAccuracy; 98 TestAccuracy = testAccuracy; 99 } 100 101 public virtual IEnumerable<double> EstimatedClassValues { 102 get { 103 return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 104 } 105 } 106 107 public virtual IEnumerable<double> EstimatedTrainingClassValues { 108 get { 109 return GetEstimatedClassValues(ProblemData.TrainingIndizes); 110 } 111 } 112 113 public virtual IEnumerable<double> EstimatedTestClassValues { 114 get { 115 return GetEstimatedClassValues(ProblemData.TestIndizes); 116 } 117 } 118 119 public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 120 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 76 protected override void OnModelChanged() { 77 evaluationCache.Clear(); 78 base.OnModelChanged(); 121 79 } 122 80 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r5809 r6760 33 33 [StorableClass] 34 34 [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")] 35 public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {35 public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel { 36 36 [Storable] 37 37 private IRegressionModel model; … … 70 70 } 71 71 72 public override IDeepCloneable Clone(Cloner cloner) {73 return new DiscriminantFunctionClassificationModel(this, cloner);74 }75 76 72 public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) { 77 73 var classValuesArr = classValues.ToArray(); … … 106 102 } 107 103 #endregion 104 105 public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData); 106 public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData); 108 107 } 109 108 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r5942 r6760 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; … … 26 25 using HeuristicLab.Core; 27 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 using HeuristicLab.Data;29 using HeuristicLab.Optimization;30 27 31 28 namespace HeuristicLab.Problems.DataAnalysis { … … 35 32 [StorableClass] 36 33 [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")] 37 public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution { 38 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 39 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 40 private const string TrainingRSquaredResultName = "Pearson's R² (training)"; 41 private const string TestRSquaredResultName = "Pearson's R² (test)"; 34 public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase { 35 protected readonly Dictionary<int, double> valueEvaluationCache; 36 protected readonly Dictionary<int, double> classValueEvaluationCache; 42 37 43 public new IDiscriminantFunctionClassificationModel Model { 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 45 protected set { 46 if (value != null && value != Model) { 47 if (Model != null) { 48 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 49 } 50 value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 51 base.Model = value; 52 } 53 } 38 [StorableConstructor] 39 protected DiscriminantFunctionClassificationSolution(bool deserializing) 40 : base(deserializing) { 41 valueEvaluationCache = new Dictionary<int, double>(); 42 classValueEvaluationCache = new Dictionary<int, double>(); 43 } 44 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 45 : base(original, cloner) { 46 valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache); 47 classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache); 48 } 49 protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 50 : base(model, problemData) { 51 valueEvaluationCache = new Dictionary<int, double>(); 52 classValueEvaluationCache = new Dictionary<int, double>(); 53 54 SetAccuracyMaximizingThresholds(); 54 55 } 55 56 56 public double TrainingMeanSquaredError { 57 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 58 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 57 public override IEnumerable<double> EstimatedClassValues { 58 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 59 } 60 public override IEnumerable<double> EstimatedTrainingClassValues { 61 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 62 } 63 public override IEnumerable<double> EstimatedTestClassValues { 64 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 59 65 } 60 66 61 public double TestMeanSquaredError { 62 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 63 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 67 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 68 var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys); 69 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 70 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 71 72 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 73 classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 74 } 75 76 return rows.Select(row => classValueEvaluationCache[row]); 64 77 } 65 78 66 public double TrainingRSquared {67 get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }68 private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }69 }70 79 71 public double TestRSquared { 72 get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; } 73 private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; } 74 } 75 76 [StorableConstructor] 77 protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { } 78 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 79 : base(original, cloner) { 80 RegisterEventHandler(); 81 } 82 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 83 : this(new DiscriminantFunctionClassificationModel(model), problemData) { 84 } 85 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 86 : base(model, problemData) { 87 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 88 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 89 Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 90 Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 91 RegisterEventHandler(); 92 SetAccuracyMaximizingThresholds(); 93 RecalculateResults(); 94 } 95 96 [StorableHook(HookType.AfterDeserialization)] 97 private void AfterDeserialization() { 98 RegisterEventHandler(); 99 } 100 101 protected new void RecalculateResults() { 102 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 103 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 104 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 105 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 106 107 OnlineCalculatorError errorState; 108 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 109 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 110 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 111 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 112 113 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 114 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 115 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 116 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 117 } 118 119 private void RegisterEventHandler() { 120 Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 121 } 122 private void Model_ThresholdsChanged(object sender, EventArgs e) { 123 OnModelThresholdsChanged(e); 124 } 125 126 public void SetAccuracyMaximizingThresholds() { 127 double[] classValues; 128 double[] thresholds; 129 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 130 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 131 132 Model.SetThresholdsAndClassValues(thresholds, classValues); 133 } 134 135 public void SetClassDistibutionCutPointThresholds() { 136 double[] classValues; 137 double[] thresholds; 138 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 139 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 140 141 Model.SetThresholdsAndClassValues(thresholds, classValues); 142 } 143 144 protected override void OnModelChanged(EventArgs e) { 145 base.OnModelChanged(e); 146 SetAccuracyMaximizingThresholds(); 147 RecalculateResults(); 148 } 149 150 protected override void OnProblemDataChanged(EventArgs e) { 151 base.OnProblemDataChanged(e); 152 SetAccuracyMaximizingThresholds(); 153 RecalculateResults(); 154 } 155 protected virtual void OnModelThresholdsChanged(EventArgs e) { 156 base.OnModelChanged(e); 157 RecalculateResults(); 158 } 159 160 public IEnumerable<double> EstimatedValues { 80 public override IEnumerable<double> EstimatedValues { 161 81 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 162 82 } 163 164 public IEnumerable<double> EstimatedTrainingValues { 83 public override IEnumerable<double> EstimatedTrainingValues { 165 84 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 166 85 } 167 168 public IEnumerable<double> EstimatedTestValues { 86 public override IEnumerable<double> EstimatedTestValues { 169 87 get { return GetEstimatedValues(ProblemData.TestIndizes); } 170 88 } 171 89 172 public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 173 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 90 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 91 var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys); 92 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 93 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 94 95 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 96 valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 97 } 98 99 return rows.Select(row => valueEvaluationCache[row]); 100 } 101 102 protected override void OnModelChanged() { 103 valueEvaluationCache.Clear(); 104 classValueEvaluationCache.Clear(); 105 base.OnModelChanged(); 106 } 107 protected override void OnModelThresholdsChanged(System.EventArgs e) { 108 classValueEvaluationCache.Clear(); 109 base.OnModelThresholdsChanged(e); 110 } 111 protected override void OnProblemDataChanged() { 112 valueEvaluationCache.Clear(); 113 classValueEvaluationCache.Clear(); 114 base.OnProblemDataChanged(); 174 115 } 175 116 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringProblemData.cs
r6228 r6760 95 95 dataset.Name = Path.GetFileName(fileName); 96 96 97 ClusteringProblemData problemData = new ClusteringProblemData(dataset, dataset. VariableNames);97 ClusteringProblemData problemData = new ClusteringProblemData(dataset, dataset.DoubleVariables); 98 98 problemData.Name = "Data imported from " + Path.GetFileName(fileName); 99 99 return problemData; -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Clustering/ClusteringSolution.cs
r6184 r6760 45 45 } 46 46 47 protected override void RecalculateResults() { 48 } 49 47 50 #region IClusteringSolution Members 48 51 -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblem.cs
r5809 r6760 48 48 public T ProblemData { 49 49 get { return ProblemDataParameter.Value; } 50 protected set { ProblemDataParameter.Value = value; } 50 protected set { 51 ProblemDataParameter.Value = value; 52 } 51 53 } 52 54 #endregion 53 55 protected DataAnalysisProblem(DataAnalysisProblem<T> original, Cloner cloner) 54 56 : base(original, cloner) { 57 RegisterEventHandlers(); 55 58 } 56 59 [StorableConstructor] … … 59 62 : base() { 60 63 Parameters.Add(new ValueParameter<T>(ProblemDataParameterName, ProblemDataParameterDescription)); 64 RegisterEventHandlers(); 65 } 66 67 [StorableHook(HookType.AfterDeserialization)] 68 private void AfterDeserialization() { 69 RegisterEventHandlers(); 61 70 } 62 71 63 72 private void RegisterEventHandlers() { 64 ProblemDataParameter.Value.Changed += new EventHandler(ProblemDataParameter_ValueChanged); 73 ProblemDataParameter.ValueChanged += new EventHandler(ProblemDataParameter_ValueChanged); 74 if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed); 65 75 } 76 66 77 private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) { 78 ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed); 67 79 OnProblemDataChanged(); 80 OnReset(); 81 } 82 83 private void ProblemData_Changed(object sender, EventArgs e) { 68 84 OnReset(); 69 85 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblemData.cs
r5847 r6760 33 33 [StorableClass] 34 34 public abstract class DataAnalysisProblemData : ParameterizedNamedItem, IDataAnalysisProblemData { 35 pr ivateconst string DatasetParameterName = "Dataset";36 pr ivateconst string InputVariablesParameterName = "InputVariables";37 pr ivateconst string TrainingPartitionParameterName = "TrainingPartition";38 pr ivateconst string TestPartitionParameterName = "TestPartition";35 protected const string DatasetParameterName = "Dataset"; 36 protected const string InputVariablesParameterName = "InputVariables"; 37 protected const string TrainingPartitionParameterName = "TrainingPartition"; 38 protected const string TestPartitionParameterName = "TestPartition"; 39 39 40 40 #region parameter properites … … 53 53 #endregion 54 54 55 #region propeties 55 #region properties 56 protected bool isEmpty = false; 57 public bool IsEmpty { 58 get { return isEmpty; } 59 } 56 60 public Dataset Dataset { 57 61 get { return DatasetParameter.Value; } … … 71 75 } 72 76 73 public IEnumerable<int> TrainingIndizes {77 public virtual IEnumerable<int> TrainingIndizes { 74 78 get { 75 79 return Enumerable.Range(TrainingPartition.Start, TrainingPartition.End - TrainingPartition.Start) 76 .Where( i => i >= 0 && i < Dataset.Rows && (i < TestPartition.Start || TestPartition.End <= i));80 .Where(IsTrainingSample); 77 81 } 78 82 } 79 public IEnumerable<int> TestIndizes {83 public virtual IEnumerable<int> TestIndizes { 80 84 get { 81 85 return Enumerable.Range(TestPartition.Start, TestPartition.End - TestPartition.Start) 82 .Where( i => i >= 0 && i < Dataset.Rows);86 .Where(IsTestSample); 83 87 } 88 } 89 90 public virtual bool IsTrainingSample(int index) { 91 return index >= 0 && index < Dataset.Rows && 92 TrainingPartition.Start <= index && index < TrainingPartition.End && 93 (index < TestPartition.Start || TestPartition.End <= index); 94 } 95 96 public virtual bool IsTestSample(int index) { 97 return index >= 0 && index < Dataset.Rows && 98 TestPartition.Start <= index && index < TestPartition.End; 84 99 } 85 100 #endregion 86 101 87 protected DataAnalysisProblemData(DataAnalysisProblemData original, Cloner cloner) : base(original, cloner) { } 102 protected DataAnalysisProblemData(DataAnalysisProblemData original, Cloner cloner) 103 : base(original, cloner) { 104 isEmpty = original.isEmpty; 105 RegisterEventHandlers(); 106 } 88 107 [StorableConstructor] 89 108 protected DataAnalysisProblemData(bool deserializing) : base(deserializing) { } 109 [StorableHook(HookType.AfterDeserialization)] 110 private void AfterDeserialization() { 111 RegisterEventHandlers(); 112 } 90 113 91 114 protected DataAnalysisProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables) { … … 93 116 if (allowedInputVariables == null) throw new ArgumentNullException("The allowedInputVariables must not be null."); 94 117 95 if (allowedInputVariables.Except(dataset. VariableNames).Any())96 throw new ArgumentException("All allowed input variables must be present in the dataset .");118 if (allowedInputVariables.Except(dataset.DoubleVariables).Any()) 119 throw new ArgumentException("All allowed input variables must be present in the dataset and of type double."); 97 120 98 var inputVariables = new CheckedItemList<StringValue>(dataset. VariableNames.Select(x => new StringValue(x)));121 var inputVariables = new CheckedItemList<StringValue>(dataset.DoubleVariables.Select(x => new StringValue(x))); 99 122 foreach (StringValue x in inputVariables) 100 123 inputVariables.SetItemCheckedState(x, allowedInputVariables.Contains(x.Value)); -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisSolution.cs
r5914 r6760 48 48 if (value != null) { 49 49 this[ModelResultName].Value = value; 50 OnModelChanged( EventArgs.Empty);50 OnModelChanged(); 51 51 } 52 52 } … … 56 56 public IDataAnalysisProblemData ProblemData { 57 57 get { return (IDataAnalysisProblemData)this[ProblemDataResultName].Value; } 58 protectedset {58 set { 59 59 if (this[ProblemDataResultName].Value != value) { 60 60 if (value != null) { … … 62 62 this[ProblemDataResultName].Value = value; 63 63 ProblemData.Changed += new EventHandler(ProblemData_Changed); 64 OnProblemDataChanged( EventArgs.Empty);64 OnProblemDataChanged(); 65 65 } 66 66 } … … 80 80 name = ItemName; 81 81 description = ItemDescription; 82 Add(new Result(ModelResultName, "The symbolicdata analysis model.", model));83 Add(new Result(ProblemDataResultName, "The symbolicdata analysis problem data.", problemData));82 Add(new Result(ModelResultName, "The data analysis model.", model)); 83 Add(new Result(ProblemDataResultName, "The data analysis problem data.", problemData)); 84 84 85 85 problemData.Changed += new EventHandler(ProblemData_Changed); 86 86 } 87 87 88 protected abstract void RecalculateResults(); 89 88 90 private void ProblemData_Changed(object sender, EventArgs e) { 89 OnProblemDataChanged( e);91 OnProblemDataChanged(); 90 92 } 91 93 92 94 public event EventHandler ModelChanged; 93 protected virtual void OnModelChanged(EventArgs e) { 95 protected virtual void OnModelChanged() { 96 RecalculateResults(); 94 97 var listeners = ModelChanged; 95 if (listeners != null) listeners(this, e);98 if (listeners != null) listeners(this, EventArgs.Empty); 96 99 } 97 100 98 101 public event EventHandler ProblemDataChanged; 99 protected virtual void OnProblemDataChanged(EventArgs e) { 102 protected virtual void OnProblemDataChanged() { 103 RecalculateResults(); 100 104 var listeners = ProblemDataChanged; 101 if (listeners != null) listeners(this, e);105 if (listeners != null) listeners(this, EventArgs.Empty); 102 106 } 103 107 -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r5809 r6760 34 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 35 36 [Storable]37 36 private List<IRegressionModel> models; 38 37 public IEnumerable<IRegressionModel> Models { 39 38 get { return new List<IRegressionModel>(models); } 40 39 } 40 41 [Storable(Name = "Models")] 42 private IEnumerable<IRegressionModel> StorableModels { 43 get { return models; } 44 set { models = value.ToList(); } 45 } 46 47 #region backwards compatiblity 3.3.5 48 [Storable(Name = "models", AllowOneWay = true)] 49 private List<IRegressionModel> OldStorableModels { 50 set { models = value; } 51 } 52 #endregion 53 41 54 [StorableConstructor] 42 55 protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { } … … 45 58 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 46 59 } 60 61 public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { } 47 62 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) 48 63 : base() { … … 57 72 58 73 #region IRegressionEnsembleModel Members 74 75 public void Add(IRegressionModel model) { 76 models.Add(model); 77 } 78 public void Remove(IRegressionModel model) { 79 models.Remove(model); 80 } 59 81 60 82 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 79 101 } 80 102 103 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 104 return new RegressionEnsembleSolution(this.Models, problemData); 105 } 106 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 107 return CreateRegressionSolution(problemData); 108 } 109 81 110 #endregion 82 111 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using System;28 using HeuristicLab.Data;29 30 30 31 namespace HeuristicLab.Problems.DataAnalysis { … … 34 35 [StorableClass] 35 36 [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")] 36 // [Creatable("Data Analysis")]37 public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {37 [Creatable("Data Analysis - Ensembles")] 38 public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution { 38 39 public new IRegressionEnsembleModel Model { 39 40 get { return (IRegressionEnsembleModel)base.Model; } 41 } 42 43 public new RegressionEnsembleProblemData ProblemData { 44 get { return (RegressionEnsembleProblemData)base.ProblemData; } 45 set { base.ProblemData = value; } 46 } 47 48 private readonly ItemCollection<IRegressionSolution> regressionSolutions; 49 public IItemCollection<IRegressionSolution> RegressionSolutions { 50 get { return regressionSolutions; } 40 51 } 41 52 … … 46 57 47 58 [StorableConstructor] 48 protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { } 49 protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 59 private RegressionEnsembleSolution(bool deserializing) 60 : base(deserializing) { 61 regressionSolutions = new ItemCollection<IRegressionSolution>(); 62 } 63 [StorableHook(HookType.AfterDeserialization)] 64 private void AfterDeserialization() { 65 foreach (var model in Model.Models) { 66 IRegressionProblemData problemData = (IRegressionProblemData) ProblemData.Clone(); 67 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 68 problemData.TrainingPartition.End = trainingPartitions[model].End; 69 problemData.TestPartition.Start = testPartitions[model].Start; 70 problemData.TestPartition.End = testPartitions[model].End; 71 72 regressionSolutions.Add(model.CreateRegressionSolution(problemData)); 73 } 74 RegisterRegressionSolutionsEventHandler(); 75 } 76 77 private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 50 78 : base(original, cloner) { 51 }52 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)53 : base(new RegressionEnsembleModel(models), problemData) {54 79 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 55 80 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 56 foreach (var model in models) { 57 trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone(); 58 testPartitions[model] = (IntRange)problemData.TestPartition.Clone(); 59 } 60 RecalculateResults(); 61 } 81 foreach (var pair in original.trainingPartitions) { 82 trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 83 } 84 foreach (var pair in original.testPartitions) { 85 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 86 } 87 88 regressionSolutions = cloner.Clone(original.regressionSolutions); 89 RegisterRegressionSolutionsEventHandler(); 90 } 91 92 public RegressionEnsembleSolution() 93 : base(new RegressionEnsembleModel(), RegressionEnsembleProblemData.EmptyProblemData) { 94 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 95 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 96 regressionSolutions = new ItemCollection<IRegressionSolution>(); 97 98 RegisterRegressionSolutionsEventHandler(); 99 } 100 101 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 102 : this(models, problemData, 103 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 104 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 105 ) { } 62 106 63 107 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 64 : base(new RegressionEnsembleModel( models), problemData) {108 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 65 109 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 66 110 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 111 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 112 113 List<IRegressionSolution> solutions = new List<IRegressionSolution>(); 67 114 var modelEnumerator = models.GetEnumerator(); 68 115 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 69 116 var testPartitionEnumerator = testPartitions.GetEnumerator(); 117 70 118 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 71 this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone(); 72 this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone(); 119 var p = (IRegressionProblemData)problemData.Clone(); 120 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 121 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 122 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 123 p.TestPartition.End = testPartitionEnumerator.Current.End; 124 125 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 73 126 } 74 127 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { … … 76 129 } 77 130 78 RecalculateResults(); 79 } 80 81 private void RecalculateResults() { 82 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 83 var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start, 84 ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start); 85 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes); 86 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 87 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 88 89 OnlineCalculatorError errorState; 90 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 91 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 92 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 93 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 94 95 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 96 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 97 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 98 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 99 100 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 101 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN; 102 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 103 TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN; 104 105 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 106 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN; 107 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 108 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN; 131 RegisterRegressionSolutionsEventHandler(); 132 regressionSolutions.AddRange(solutions); 109 133 } 110 134 … … 112 136 return new RegressionEnsembleSolution(this, cloner); 113 137 } 114 138 private void RegisterRegressionSolutionsEventHandler() { 139 regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded); 140 regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved); 141 regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset); 142 } 143 144 protected override void RecalculateResults() { 145 CalculateResults(); 146 } 147 148 #region Evaluation 115 149 public override IEnumerable<double> EstimatedTrainingValues { 116 150 get { 117 var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);151 var rows = ProblemData.TrainingIndizes; 118 152 var estimatedValuesEnumerators = (from model in Model.Models 119 153 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 120 154 .ToList(); 121 155 var rowsEnumerator = rows.GetEnumerator(); 156 // aggregate to make sure that MoveNext is called for all enumerators 122 157 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 123 158 int currentRow = rowsEnumerator.Current; 124 159 125 160 var selectedEnumerators = from pair in estimatedValuesEnumerators 126 where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) || 127 (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End) 161 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 128 162 select pair.EstimatedValuesEnumerator; 129 163 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); … … 134 168 public override IEnumerable<double> EstimatedTestValues { 135 169 get { 170 var rows = ProblemData.TestIndizes; 136 171 var estimatedValuesEnumerators = (from model in Model.Models 137 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })172 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 138 173 .ToList(); 139 174 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 175 // aggregate to make sure that MoveNext is called for all enumerators 140 176 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 141 177 int currentRow = rowsEnumerator.Current; 142 178 143 179 var selectedEnumerators = from pair in estimatedValuesEnumerators 144 where testPartitions == null || !testPartitions.ContainsKey(pair.Model) || 145 (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End) 180 where RowIsTestForModel(currentRow, pair.Model) 146 181 select pair.EstimatedValuesEnumerator; 147 182 … … 149 184 } 150 185 } 186 } 187 188 private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) { 189 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 190 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 191 } 192 193 private bool RowIsTestForModel(int currentRow, IRegressionModel model) { 194 return testPartitions == null || !testPartitions.ContainsKey(model) || 195 (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End); 151 196 } 152 197 … … 168 213 169 214 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 170 return estimatedValues.Average(); 171 } 172 173 //[Storable] 174 //private string name; 175 //public string Name { 176 // get { 177 // return name; 178 // } 179 // set { 180 // if (value != null && value != name) { 181 // var cancelEventArgs = new CancelEventArgs<string>(value); 182 // OnNameChanging(cancelEventArgs); 183 // if (cancelEventArgs.Cancel == false) { 184 // name = value; 185 // OnNamedChanged(EventArgs.Empty); 186 // } 187 // } 188 // } 189 //} 190 191 //public bool CanChangeName { 192 // get { return true; } 193 //} 194 195 //[Storable] 196 //private string description; 197 //public string Description { 198 // get { 199 // return description; 200 // } 201 // set { 202 // if (value != null && value != description) { 203 // description = value; 204 // OnDescriptionChanged(EventArgs.Empty); 205 // } 206 // } 207 //} 208 209 //public bool CanChangeDescription { 210 // get { return true; } 211 //} 212 213 //#region events 214 //public event EventHandler<CancelEventArgs<string>> NameChanging; 215 //private void OnNameChanging(CancelEventArgs<string> cancelEventArgs) { 216 // var listener = NameChanging; 217 // if (listener != null) listener(this, cancelEventArgs); 218 //} 219 220 //public event EventHandler NameChanged; 221 //private void OnNamedChanged(EventArgs e) { 222 // var listener = NameChanged; 223 // if (listener != null) listener(this, e); 224 //} 225 226 //public event EventHandler DescriptionChanged; 227 //private void OnDescriptionChanged(EventArgs e) { 228 // var listener = DescriptionChanged; 229 // if (listener != null) listener(this, e); 230 //} 231 // #endregion 215 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 216 } 217 #endregion 218 219 protected override void OnProblemDataChanged() { 220 IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset, 221 ProblemData.AllowedInputVariables, 222 ProblemData.TargetVariable); 223 problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start; 224 problemData.TrainingPartition.End = ProblemData.TrainingPartition.End; 225 problemData.TestPartition.Start = ProblemData.TestPartition.Start; 226 problemData.TestPartition.End = ProblemData.TestPartition.End; 227 228 foreach (var solution in RegressionSolutions) { 229 if (solution is RegressionEnsembleSolution) 230 solution.ProblemData = ProblemData; 231 else 232 solution.ProblemData = problemData; 233 } 234 foreach (var trainingPartition in trainingPartitions.Values) { 235 trainingPartition.Start = ProblemData.TrainingPartition.Start; 236 trainingPartition.End = ProblemData.TrainingPartition.End; 237 } 238 foreach (var testPartition in testPartitions.Values) { 239 testPartition.Start = ProblemData.TestPartition.Start; 240 testPartition.End = ProblemData.TestPartition.End; 241 } 242 243 base.OnProblemDataChanged(); 244 } 245 246 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 247 regressionSolutions.AddRange(solutions); 248 } 249 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 250 regressionSolutions.RemoveRange(solutions); 251 } 252 253 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 254 foreach (var solution in e.Items) AddRegressionSolution(solution); 255 RecalculateResults(); 256 } 257 private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 258 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 259 RecalculateResults(); 260 } 261 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 262 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 263 foreach (var solution in e.Items) AddRegressionSolution(solution); 264 RecalculateResults(); 265 } 266 267 private void AddRegressionSolution(IRegressionSolution solution) { 268 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 269 Model.Add(solution.Model); 270 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 271 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 272 } 273 274 private void RemoveRegressionSolution(IRegressionSolution solution) { 275 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 276 Model.Remove(solution.Model); 277 trainingPartitions.Remove(solution.Model); 278 testPartitions.Remove(solution.Model); 279 } 232 280 } 233 281 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs
r5809 r6760 33 33 [StorableClass] 34 34 [Item("RegressionProblemData", "Represents an item containing all data defining a regression problem.")] 35 public sealedclass RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData {36 pr ivateconst string TargetVariableParameterName = "TargetVariable";35 public class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData { 36 protected const string TargetVariableParameterName = "TargetVariable"; 37 37 38 38 #region default data … … 64 64 {0.83763905, 0.468046718} 65 65 }; 66 private static Dataset defaultDataset; 67 private static IEnumerable<string> defaultAllowedInputVariables; 68 private static string defaultTargetVariable; 66 private static readonly Dataset defaultDataset; 67 private static readonly IEnumerable<string> defaultAllowedInputVariables; 68 private static readonly string defaultTargetVariable; 69 70 private static readonly RegressionProblemData emptyProblemData; 71 public static RegressionProblemData EmptyProblemData { 72 get { return emptyProblemData; } 73 } 69 74 70 75 static RegressionProblemData() { … … 74 79 defaultAllowedInputVariables = new List<string>() { "x" }; 75 80 defaultTargetVariable = "y"; 81 82 var problemData = new RegressionProblemData(); 83 problemData.Parameters.Clear(); 84 problemData.Name = "Empty Regression ProblemData"; 85 problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded."; 86 problemData.isEmpty = true; 87 88 problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset())); 89 problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, "")); 90 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 91 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 92 problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>())); 93 emptyProblemData = problemData; 76 94 } 77 95 #endregion 78 96 79 public IValueParameter<StringValue> TargetVariableParameter {80 get { return ( IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }97 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 98 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 81 99 } 82 100 public string TargetVariable { … … 85 103 86 104 [StorableConstructor] 87 pr ivateRegressionProblemData(bool deserializing) : base(deserializing) { }105 protected RegressionProblemData(bool deserializing) : base(deserializing) { } 88 106 [StorableHook(HookType.AfterDeserialization)] 89 107 private void AfterDeserialization() { … … 91 109 } 92 110 93 94 private RegressionProblemData(RegressionProblemData original, Cloner cloner) 111 protected RegressionProblemData(RegressionProblemData original, Cloner cloner) 95 112 : base(original, cloner) { 96 113 RegisterParameterEvents(); 97 114 } 98 public override IDeepCloneable Clone(Cloner cloner) { return new RegressionProblemData(this, cloner); } 115 public override IDeepCloneable Clone(Cloner cloner) { 116 if (this == emptyProblemData) return emptyProblemData; 117 return new RegressionProblemData(this, cloner); 118 } 99 119 100 120 public RegressionProblemData() … … 124 144 dataset.Name = Path.GetFileName(fileName); 125 145 126 RegressionProblemData problemData = new RegressionProblemData(dataset, dataset. VariableNames.Skip(1), dataset.VariableNames.First());146 RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First()); 127 147 problemData.Name = "Data imported from " + Path.GetFileName(fileName); 128 148 return problemData; -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 using HeuristicLab.Data;27 using HeuristicLab.Optimization;28 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 26 … … 33 30 /// </summary> 34 31 [StorableClass] 35 public class RegressionSolution : DataAnalysisSolution, IRegressionSolution { 36 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 37 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 38 private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)"; 39 private const string TestSquaredCorrelationResultName = "Pearson's R² (test)"; 40 private const string TrainingRelativeErrorResultName = "Average relative error (training)"; 41 private const string TestRelativeErrorResultName = "Average relative error (test)"; 42 private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)"; 43 private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)"; 32 public abstract class RegressionSolution : RegressionSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 44 34 45 public new IRegressionModel Model { 46 get { return (IRegressionModel)base.Model; } 47 protected set { base.Model = value; } 35 [StorableConstructor] 36 protected RegressionSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 40 protected RegressionSolution(RegressionSolution original, Cloner cloner) 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 43 } 44 protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData) 45 : base(model, problemData) { 46 evaluationCache = new Dictionary<int, double>(); 48 47 } 49 48 50 public new IRegressionProblemData ProblemData { 51 get { return (IRegressionProblemData)base.ProblemData; } 52 protected set { base.ProblemData = value; } 49 protected override void RecalculateResults() { 50 CalculateResults(); 53 51 } 54 52 55 public double TrainingMeanSquaredError { 56 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 57 protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 53 public override IEnumerable<double> EstimatedValues { 54 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 55 } 56 public override IEnumerable<double> EstimatedTrainingValues { 57 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 58 } 59 public override IEnumerable<double> EstimatedTestValues { 60 get { return GetEstimatedValues(ProblemData.TestIndizes); } 58 61 } 59 62 60 public double TestMeanSquaredError { 61 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 62 protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 63 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 64 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 65 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 66 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 67 68 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 69 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 70 } 71 72 return rows.Select(row => evaluationCache[row]); 63 73 } 64 74 65 p ublic double TrainingRSquared{66 get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }67 protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }75 protected override void OnProblemDataChanged() { 76 evaluationCache.Clear(); 77 base.OnProblemDataChanged(); 68 78 } 69 79 70 public double TestRSquared { 71 get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; } 72 protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; } 73 } 74 75 public double TrainingRelativeError { 76 get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; } 77 protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; } 78 } 79 80 public double TestRelativeError { 81 get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; } 82 protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; } 83 } 84 85 public double TrainingNormalizedMeanSquaredError { 86 get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; } 87 protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; } 88 } 89 90 public double TestNormalizedMeanSquaredError { 91 get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; } 92 protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; } 93 } 94 95 96 [StorableConstructor] 97 protected RegressionSolution(bool deserializing) : base(deserializing) { } 98 protected RegressionSolution(RegressionSolution original, Cloner cloner) 99 : base(original, cloner) { 100 } 101 public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData) 102 : base(model, problemData) { 103 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 104 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 105 Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 106 Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 107 Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue())); 108 Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue())); 109 Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue())); 110 Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue())); 111 112 RecalculateResults(); 113 } 114 115 public override IDeepCloneable Clone(Cloner cloner) { 116 return new RegressionSolution(this, cloner); 117 } 118 119 protected override void OnProblemDataChanged(EventArgs e) { 120 base.OnProblemDataChanged(e); 121 RecalculateResults(); 122 } 123 protected override void OnModelChanged(EventArgs e) { 124 base.OnModelChanged(e); 125 RecalculateResults(); 126 } 127 128 private void RecalculateResults() { 129 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 130 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 131 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 132 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 133 134 OnlineCalculatorError errorState; 135 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 136 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 137 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 138 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 139 140 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 141 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 142 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 143 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 144 145 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 146 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN; 147 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 148 TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN; 149 150 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 151 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN; 152 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 153 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN; 154 } 155 156 public virtual IEnumerable<double> EstimatedValues { 157 get { 158 return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 159 } 160 } 161 162 public virtual IEnumerable<double> EstimatedTrainingValues { 163 get { 164 return GetEstimatedValues(ProblemData.TrainingIndizes); 165 } 166 } 167 168 public virtual IEnumerable<double> EstimatedTestValues { 169 get { 170 return GetEstimatedValues(ProblemData.TestIndizes); 171 } 172 } 173 174 public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 175 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 80 protected override void OnModelChanged() { 81 evaluationCache.Clear(); 82 base.OnModelChanged(); 176 83 } 177 84 }
Note: See TracChangeset
for help on using the changeset viewer.