- Timestamp:
- 08/01/11 17:48:53 (13 years ago)
- Location:
- branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation
- Files:
-
- 14 edited
- 3 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs
r6239 r6618 58 58 59 59 #region IClassificationEnsembleModel Members 60 public void Add(IClassificationModel model) { 61 models.Add(model); 62 } 63 public void Remove(IClassificationModel model) { 64 models.Remove(model); 65 } 60 66 61 67 public IEnumerable<IEnumerable<double>> GetEstimatedClassValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 86 92 } 87 93 94 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 95 return new ClassificationEnsembleSolution(models, problemData); 96 } 88 97 #endregion 89 98 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleProblemData.cs
r6239 r6618 37 37 public override IEnumerable<int> TrainingIndizes { 38 38 get { 39 return Enumerable.Range(TrainingPartition.Start, T estPartition.End - TestPartition.Start);39 return Enumerable.Range(TrainingPartition.Start, TrainingPartition.End - TrainingPartition.Start); 40 40 } 41 41 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r6377 r6618 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using HeuristicLab.Data;28 using System;29 30 30 31 namespace HeuristicLab.Problems.DataAnalysis { … … 35 36 [Item("Classification Ensemble Solution", "A classification solution that contains an ensemble of multiple classification models")] 36 37 // [Creatable("Data Analysis")] 37 public class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 38 38 public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 39 39 public new IClassificationEnsembleModel Model { 40 set { base.Model = value; }41 40 get { return (IClassificationEnsembleModel)base.Model; } 41 } 42 43 private readonly ItemCollection<IClassificationSolution> classificationSolutions; 44 public IItemCollection<IClassificationSolution> ClassificationSolutions { 45 get { return classificationSolutions; } 42 46 } 43 47 … … 47 51 private Dictionary<IClassificationModel, IntRange> testPartitions; 48 52 49 50 53 [StorableConstructor] 51 protected ClassificationEnsembleSolution(bool deserializing) : base(deserializing) { } 52 protected ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 54 private ClassificationEnsembleSolution(bool deserializing) 55 : base(deserializing) { 56 classificationSolutions = new ItemCollection<IClassificationSolution>(); 57 } 58 [StorableHook(HookType.AfterDeserialization)] 59 private void AfterDeserialization() { 60 foreach (var model in Model.Models) { 61 IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone(); 62 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 63 problemData.TrainingPartition.End = trainingPartitions[model].End; 64 problemData.TestPartition.Start = testPartitions[model].Start; 65 problemData.TestPartition.End = testPartitions[model].End; 66 67 classificationSolutions.Add(model.CreateClassificationSolution(problemData)); 68 } 69 RegisterClassificationSolutionsEventHandler(); 70 } 71 72 private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner) 53 73 : base(original, cloner) { 54 74 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); … … 60 80 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 61 81 } 62 RecalculateResults(); 63 } 82 83 classificationSolutions = cloner.Clone(original.classificationSolutions); 84 RegisterClassificationSolutionsEventHandler(); 85 } 86 64 87 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData) 65 : base(new ClassificationEnsembleModel(models), new ClassificationEnsembleProblemData(problemData)) { 66 this.name = ItemName; 67 this.description = ItemDescription; 68 trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 69 testPartitions = new Dictionary<IClassificationModel, IntRange>(); 70 foreach (var model in models) { 71 trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone(); 72 testPartitions[model] = (IntRange)problemData.TestPartition.Clone(); 73 } 74 RecalculateResults(); 75 } 88 : this(models, problemData, 89 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 90 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 91 ) { } 76 92 77 93 public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 78 : base(new ClassificationEnsembleModel( models), new ClassificationEnsembleProblemData(problemData)) {94 : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) { 79 95 this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>(); 80 96 this.testPartitions = new Dictionary<IClassificationModel, IntRange>(); 97 this.classificationSolutions = new ItemCollection<IClassificationSolution>(); 98 99 List<IClassificationSolution> solutions = new List<IClassificationSolution>(); 81 100 var modelEnumerator = models.GetEnumerator(); 82 101 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 83 102 var testPartitionEnumerator = testPartitions.GetEnumerator(); 103 84 104 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 85 this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone(); 86 this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone(); 105 var p = (IClassificationProblemData)problemData.Clone(); 106 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 107 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 108 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 109 p.TestPartition.End = testPartitionEnumerator.Current.End; 110 111 solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p)); 87 112 } 88 113 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 89 114 throw new ArgumentException(); 90 115 } 91 RecalculateResults(); 116 117 RegisterClassificationSolutionsEventHandler(); 118 classificationSolutions.AddRange(solutions); 92 119 } 93 120 … … 95 122 return new ClassificationEnsembleSolution(this, cloner); 96 123 } 97 124 private void RegisterClassificationSolutionsEventHandler() { 125 classificationSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsAdded); 126 classificationSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_ItemsRemoved); 127 classificationSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IClassificationSolution>(classificationSolutions_CollectionReset); 128 } 129 130 protected override void RecalculateResults() { 131 CalculateResults(); 132 } 133 134 #region Evaluation 98 135 public override IEnumerable<double> EstimatedTrainingClassValues { 99 136 get { … … 169 206 .First(); 170 207 } 208 #endregion 209 210 public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 211 classificationSolutions.AddRange(solutions); 212 } 213 public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 214 classificationSolutions.RemoveRange(solutions); 215 } 216 217 private void classificationSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 218 foreach (var solution in e.Items) AddClassificationSolution(solution); 219 RecalculateResults(); 220 } 221 private void classificationSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 222 foreach (var solution in e.Items) RemoveClassificationSolution(solution); 223 RecalculateResults(); 224 } 225 private void classificationSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IClassificationSolution> e) { 226 foreach (var solution in e.OldItems) RemoveClassificationSolution(solution); 227 foreach (var solution in e.Items) AddClassificationSolution(solution); 228 RecalculateResults(); 229 } 230 231 private void AddClassificationSolution(IClassificationSolution solution) { 232 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 233 Model.Add(solution.Model); 234 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 235 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 236 } 237 238 private void RemoveClassificationSolution(IClassificationSolution solution) { 239 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 240 Model.Remove(solution.Model); 241 trainingPartitions.Remove(solution.Model); 242 testPartitions.Remove(solution.Model); 243 } 171 244 } 172 245 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs
r6223 r6618 185 185 186 186 #region parameter properties 187 public IValueParameter<StringValue> TargetVariableParameter {188 get { return ( IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }187 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 188 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 189 189 } 190 190 public IFixedValueParameter<StringMatrix> ClassNamesParameter { -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r6415 r6618 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Data;26 using HeuristicLab.Optimization;27 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 26 … … 32 30 /// </summary> 33 31 [StorableClass] 34 public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution { 35 private const string TrainingAccuracyResultName = "Accuracy (training)"; 36 private const string TestAccuracyResultName = "Accuracy (test)"; 37 38 public new IClassificationModel Model { 39 get { return (IClassificationModel)base.Model; } 40 protected set { base.Model = value; } 41 } 42 43 public new IClassificationProblemData ProblemData { 44 get { return (IClassificationProblemData)base.ProblemData; } 45 protected set { base.ProblemData = value; } 46 } 47 48 public double TrainingAccuracy { 49 get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; } 50 private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; } 51 } 52 53 public double TestAccuracy { 54 get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; } 55 private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; } 56 } 32 public abstract class ClassificationSolution : ClassificationSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 57 34 58 35 [StorableConstructor] 59 protected ClassificationSolution(bool deserializing) : base(deserializing) { } 36 protected ClassificationSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 60 40 protected ClassificationSolution(ClassificationSolution original, Cloner cloner) 61 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 62 43 } 63 44 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData) 64 45 : base(model, problemData) { 65 Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue())); 66 Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue())); 67 CalculateResults(); 46 evaluationCache = new Dictionary<int, double>(); 68 47 } 69 48 70 public override IDeepCloneable Clone(Cloner cloner) { 71 return new ClassificationSolution(this, cloner); 49 public override IEnumerable<double> EstimatedClassValues { 50 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 51 } 52 public override IEnumerable<double> EstimatedTrainingClassValues { 53 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 54 } 55 public override IEnumerable<double> EstimatedTestClassValues { 56 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 72 57 } 73 58 74 protected override void RecalculateResults() { 75 CalculateResults(); 59 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 60 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 61 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 62 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 63 64 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 65 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 66 } 67 68 return rows.Select(row => evaluationCache[row]); 76 69 } 77 70 78 private void CalculateResults() { 79 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 80 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 81 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 82 IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 83 84 OnlineCalculatorError errorState; 85 double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState); 86 if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN; 87 double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState); 88 if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN; 89 90 TrainingAccuracy = trainingAccuracy; 91 TestAccuracy = testAccuracy; 71 protected override void OnProblemDataChanged() { 72 evaluationCache.Clear(); 73 base.OnProblemDataChanged(); 92 74 } 93 75 94 public virtual IEnumerable<double> EstimatedClassValues { 95 get { 96 return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 97 } 98 } 99 100 public virtual IEnumerable<double> EstimatedTrainingClassValues { 101 get { 102 return GetEstimatedClassValues(ProblemData.TrainingIndizes); 103 } 104 } 105 106 public virtual IEnumerable<double> EstimatedTestClassValues { 107 get { 108 return GetEstimatedClassValues(ProblemData.TestIndizes); 109 } 110 } 111 112 public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 113 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 76 protected override void OnModelChanged() { 77 evaluationCache.Clear(); 78 base.OnModelChanged(); 114 79 } 115 80 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
r5809 r6618 33 33 [StorableClass] 34 34 [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")] 35 public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {35 public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel { 36 36 [Storable] 37 37 private IRegressionModel model; … … 70 70 } 71 71 72 public override IDeepCloneable Clone(Cloner cloner) {73 return new DiscriminantFunctionClassificationModel(this, cloner);74 }75 76 72 public void SetThresholdsAndClassValues(IEnumerable<double> thresholds, IEnumerable<double> classValues) { 77 73 var classValuesArr = classValues.ToArray(); … … 106 102 } 107 103 #endregion 104 105 public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData); 106 public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData); 108 107 } 109 108 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r6415 r6618 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 25 using HeuristicLab.Core; 27 using HeuristicLab.Data;28 using HeuristicLab.Optimization;29 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 30 27 … … 35 32 [StorableClass] 36 33 [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")] 37 public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution { 38 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 39 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 40 private const string TrainingRSquaredResultName = "Pearson's R² (training)"; 41 private const string TestRSquaredResultName = "Pearson's R² (test)"; 34 public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase { 35 protected readonly Dictionary<int, double> valueEvaluationCache; 36 protected readonly Dictionary<int, double> classValueEvaluationCache; 42 37 43 public new IDiscriminantFunctionClassificationModel Model { 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 45 protected set { 46 if (value != null && value != Model) { 47 if (Model != null) { 48 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 49 } 50 value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 51 base.Model = value; 52 } 53 } 38 [StorableConstructor] 39 protected DiscriminantFunctionClassificationSolution(bool deserializing) 40 : base(deserializing) { 41 valueEvaluationCache = new Dictionary<int, double>(); 42 classValueEvaluationCache = new Dictionary<int, double>(); 43 } 44 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 45 : base(original, cloner) { 46 valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache); 47 classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache); 48 } 49 protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 50 : base(model, problemData) { 51 valueEvaluationCache = new Dictionary<int, double>(); 52 classValueEvaluationCache = new Dictionary<int, double>(); 53 54 SetAccuracyMaximizingThresholds(); 54 55 } 55 56 56 public double TrainingMeanSquaredError { 57 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 58 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 57 public override IEnumerable<double> EstimatedClassValues { 58 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 59 } 60 public override IEnumerable<double> EstimatedTrainingClassValues { 61 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 62 } 63 public override IEnumerable<double> EstimatedTestClassValues { 64 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 59 65 } 60 66 61 public double TestMeanSquaredError { 62 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 63 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 67 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 68 var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys); 69 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 70 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 71 72 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 73 classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 74 } 75 76 return rows.Select(row => classValueEvaluationCache[row]); 64 77 } 65 78 66 public double TrainingRSquared {67 get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }68 private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }69 }70 79 71 public double TestRSquared { 72 get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; } 73 private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; } 74 } 75 76 [StorableConstructor] 77 protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { } 78 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 79 : base(original, cloner) { 80 RegisterEventHandler(); 81 } 82 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 83 : this(new DiscriminantFunctionClassificationModel(model), problemData) { 84 } 85 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 86 : base(model, problemData) { 87 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 88 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 89 Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 90 Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 91 SetAccuracyMaximizingThresholds(); 92 93 //mkommend: important to recalculate accuracy because during the calculation before no thresholds were present 94 base.RecalculateResults(); 95 CalculateResults(); 96 RegisterEventHandler(); 97 } 98 99 [StorableHook(HookType.AfterDeserialization)] 100 private void AfterDeserialization() { 101 RegisterEventHandler(); 102 } 103 104 protected override void OnModelChanged(EventArgs e) { 105 DeregisterEventHandler(); 106 SetAccuracyMaximizingThresholds(); 107 RegisterEventHandler(); 108 base.OnModelChanged(e); 109 } 110 111 protected override void RecalculateResults() { 112 base.RecalculateResults(); 113 CalculateResults(); 114 } 115 116 private void CalculateResults() { 117 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 118 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 119 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 120 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 121 122 OnlineCalculatorError errorState; 123 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 124 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 125 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 126 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 127 128 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 129 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 130 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 131 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 132 } 133 134 private void RegisterEventHandler() { 135 Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 136 } 137 private void DeregisterEventHandler() { 138 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 139 } 140 private void Model_ThresholdsChanged(object sender, EventArgs e) { 141 OnModelThresholdsChanged(e); 142 } 143 144 public void SetAccuracyMaximizingThresholds() { 145 double[] classValues; 146 double[] thresholds; 147 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 148 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 149 150 Model.SetThresholdsAndClassValues(thresholds, classValues); 151 } 152 153 public void SetClassDistibutionCutPointThresholds() { 154 double[] classValues; 155 double[] thresholds; 156 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 157 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 158 159 Model.SetThresholdsAndClassValues(thresholds, classValues); 160 } 161 162 protected virtual void OnModelThresholdsChanged(EventArgs e) { 163 RecalculateResults(); 164 } 165 166 public IEnumerable<double> EstimatedValues { 80 public override IEnumerable<double> EstimatedValues { 167 81 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 168 82 } 169 170 public IEnumerable<double> EstimatedTrainingValues { 83 public override IEnumerable<double> EstimatedTrainingValues { 171 84 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 172 85 } 173 174 public IEnumerable<double> EstimatedTestValues { 86 public override IEnumerable<double> EstimatedTestValues { 175 87 get { return GetEstimatedValues(ProblemData.TestIndizes); } 176 88 } 177 89 178 public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 179 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 90 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 91 var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys); 92 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 93 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 94 95 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 96 valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 97 } 98 99 return rows.Select(row => valueEvaluationCache[row]); 100 } 101 102 protected override void OnModelChanged() { 103 valueEvaluationCache.Clear(); 104 classValueEvaluationCache.Clear(); 105 base.OnModelChanged(); 106 } 107 protected override void OnModelThresholdsChanged(System.EventArgs e) { 108 classValueEvaluationCache.Clear(); 109 base.OnModelThresholdsChanged(e); 110 } 111 protected override void OnProblemDataChanged() { 112 valueEvaluationCache.Clear(); 113 classValueEvaluationCache.Clear(); 114 base.OnProblemDataChanged(); 180 115 } 181 116 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblem.cs
r5809 r6618 48 48 public T ProblemData { 49 49 get { return ProblemDataParameter.Value; } 50 protected set { ProblemDataParameter.Value = value; } 50 protected set { 51 ProblemDataParameter.Value = value; 52 } 51 53 } 52 54 #endregion 53 55 protected DataAnalysisProblem(DataAnalysisProblem<T> original, Cloner cloner) 54 56 : base(original, cloner) { 57 RegisterEventHandlers(); 55 58 } 56 59 [StorableConstructor] … … 59 62 : base() { 60 63 Parameters.Add(new ValueParameter<T>(ProblemDataParameterName, ProblemDataParameterDescription)); 64 RegisterEventHandlers(); 65 } 66 67 [StorableHook(HookType.AfterDeserialization)] 68 private void AfterDeserialization() { 69 RegisterEventHandlers(); 61 70 } 62 71 63 72 private void RegisterEventHandlers() { 73 ProblemDataParameter.ValueChanged += new EventHandler(ProblemDataParameter_ValueChanged); 74 if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += new EventHandler(ProblemDataParameter_ValueChanged); 75 } 76 77 private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) { 64 78 ProblemDataParameter.Value.Changed += new EventHandler(ProblemDataParameter_ValueChanged); 65 }66 private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {67 79 OnProblemDataChanged(); 68 80 OnReset(); -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisProblemData.cs
r6238 r6618 91 91 [StorableConstructor] 92 92 protected DataAnalysisProblemData(bool deserializing) : base(deserializing) { } 93 [StorableHook(HookType.AfterDeserialization)] 94 private void AfterDeserialization() { 95 RegisterEventHandlers(); 96 } 93 97 94 98 protected DataAnalysisProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables) { -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/DataAnalysisSolution.cs
r6415 r6618 48 48 if (value != null) { 49 49 this[ModelResultName].Value = value; 50 OnModelChanged( EventArgs.Empty);50 OnModelChanged(); 51 51 } 52 52 } … … 62 62 this[ProblemDataResultName].Value = value; 63 63 ProblemData.Changed += new EventHandler(ProblemData_Changed); 64 OnProblemDataChanged( EventArgs.Empty);64 OnProblemDataChanged(); 65 65 } 66 66 } … … 80 80 name = ItemName; 81 81 description = ItemDescription; 82 Add(new Result(ModelResultName, "The symbolicdata analysis model.", model));83 Add(new Result(ProblemDataResultName, "The symbolicdata analysis problem data.", problemData));82 Add(new Result(ModelResultName, "The data analysis model.", model)); 83 Add(new Result(ProblemDataResultName, "The data analysis problem data.", problemData)); 84 84 85 85 problemData.Changed += new EventHandler(ProblemData_Changed); … … 89 89 90 90 private void ProblemData_Changed(object sender, EventArgs e) { 91 OnProblemDataChanged( e);91 OnProblemDataChanged(); 92 92 } 93 93 94 94 public event EventHandler ModelChanged; 95 protected virtual void OnModelChanged(EventArgs e) { 95 protected virtual void OnModelChanged() { 96 RecalculateResults(); 96 97 RecalculateResults(); 97 98 var listeners = ModelChanged; 98 if (listeners != null) listeners(this, e);99 if (listeners != null) listeners(this, EventArgs.Empty); 99 100 } 100 101 101 102 public event EventHandler ProblemDataChanged; 102 protected virtual void OnProblemDataChanged(EventArgs e) { 103 protected virtual void OnProblemDataChanged() { 104 RecalculateResults(); 103 105 RecalculateResults(); 104 106 var listeners = ProblemDataChanged; 105 if (listeners != null) listeners(this, e);107 if (listeners != null) listeners(this, EventArgs.Empty); 106 108 } 107 109 -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r5809 r6618 34 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 35 36 [Storable]37 36 private List<IRegressionModel> models; 38 37 public IEnumerable<IRegressionModel> Models { 39 38 get { return new List<IRegressionModel>(models); } 40 39 } 40 41 [Storable(Name = "Models")] 42 private IEnumerable<IRegressionModel> StorableModels { 43 get { return models; } 44 set { models = value.ToList(); } 45 } 46 47 #region backwards compatiblity 3.3.5 48 [Storable(Name = "models", AllowOneWay = true)] 49 private List<IRegressionModel> OldStorableModels { 50 set { models = value; } 51 } 52 #endregion 53 41 54 [StorableConstructor] 42 55 protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { } … … 57 70 58 71 #region IRegressionEnsembleModel Members 72 73 public void Add(IRegressionModel model) { 74 models.Add(model); 75 } 76 public void Remove(IRegressionModel model) { 77 models.Remove(model); 78 } 59 79 60 80 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 79 99 } 80 100 101 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 102 return new RegressionEnsembleSolution(this.Models, problemData); 103 } 104 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 105 return CreateRegressionSolution(problemData); 106 } 107 81 108 #endregion 82 109 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r6377 r6618 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using System;28 using HeuristicLab.Data;29 30 30 31 namespace HeuristicLab.Problems.DataAnalysis { … … 35 36 [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")] 36 37 // [Creatable("Data Analysis")] 37 public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {38 public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution { 38 39 public new IRegressionEnsembleModel Model { 39 40 get { return (IRegressionEnsembleModel)base.Model; } 41 } 42 43 private readonly ItemCollection<IRegressionSolution> regressionSolutions; 44 public IItemCollection<IRegressionSolution> RegressionSolutions { 45 get { return regressionSolutions; } 40 46 } 41 47 … … 46 52 47 53 [StorableConstructor] 48 protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { } 49 protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 54 private RegressionEnsembleSolution(bool deserializing) 55 : base(deserializing) { 56 regressionSolutions = new ItemCollection<IRegressionSolution>(); 57 } 58 [StorableHook(HookType.AfterDeserialization)] 59 private void AfterDeserialization() { 60 foreach (var model in Model.Models) { 61 IRegressionProblemData problemData = (IRegressionProblemData)ProblemData.Clone(); 62 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 63 problemData.TrainingPartition.End = trainingPartitions[model].End; 64 problemData.TestPartition.Start = testPartitions[model].Start; 65 problemData.TestPartition.End = testPartitions[model].End; 66 67 regressionSolutions.Add(model.CreateRegressionSolution(problemData)); 68 } 69 RegisterRegressionSolutionsEventHandler(); 70 } 71 72 private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 50 73 : base(original, cloner) { 51 74 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); … … 57 80 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 58 81 } 59 RecalculateResults(); 82 83 regressionSolutions = cloner.Clone(original.regressionSolutions); 84 RegisterRegressionSolutionsEventHandler(); 60 85 } 61 86 62 87 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 63 : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) { 64 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 65 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 66 foreach (var model in models) { 67 trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone(); 68 testPartitions[model] = (IntRange)problemData.TestPartition.Clone(); 69 } 70 RecalculateResults(); 71 } 88 : this(models, problemData, 89 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 90 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 91 ) { } 72 92 73 93 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 74 : base(new RegressionEnsembleModel( models), new RegressionEnsembleProblemData(problemData)) {94 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 75 95 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 76 96 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 97 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 98 99 List<IRegressionSolution> solutions = new List<IRegressionSolution>(); 77 100 var modelEnumerator = models.GetEnumerator(); 78 101 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 79 102 var testPartitionEnumerator = testPartitions.GetEnumerator(); 103 80 104 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 81 this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone(); 82 this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone(); 105 var p = (IRegressionProblemData)problemData.Clone(); 106 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 107 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 108 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 109 p.TestPartition.End = testPartitionEnumerator.Current.End; 110 111 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 83 112 } 84 113 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 85 114 throw new ArgumentException(); 86 115 } 87 RecalculateResults(); 116 117 RegisterRegressionSolutionsEventHandler(); 118 regressionSolutions.AddRange(solutions); 88 119 } 89 120 … … 91 122 return new RegressionEnsembleSolution(this, cloner); 92 123 } 93 124 private void RegisterRegressionSolutionsEventHandler() { 125 regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded); 126 regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved); 127 regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset); 128 } 129 130 protected override void RecalculateResults() { 131 CalculateResults(); 132 } 133 134 #region Evaluation 94 135 public override IEnumerable<double> EstimatedTrainingValues { 95 136 get { … … 160 201 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 161 202 } 203 #endregion 204 205 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 206 solutions.OfType<RegressionEnsembleSolution>().SelectMany(ensemble => ensemble.RegressionSolutions); 207 regressionSolutions.AddRange(solutions); 208 } 209 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 210 regressionSolutions.RemoveRange(solutions); 211 } 212 213 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 214 foreach (var solution in e.Items) AddRegressionSolution(solution); 215 RecalculateResults(); 216 } 217 private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 218 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 219 RecalculateResults(); 220 } 221 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 222 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 223 foreach (var solution in e.Items) AddRegressionSolution(solution); 224 RecalculateResults(); 225 } 226 227 private void AddRegressionSolution(IRegressionSolution solution) { 228 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 229 Model.Add(solution.Model); 230 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 231 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 232 } 233 234 private void RemoveRegressionSolution(IRegressionSolution solution) { 235 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 236 Model.Remove(solution.Model); 237 trainingPartitions.Remove(solution.Model); 238 testPartitions.Remove(solution.Model); 239 } 162 240 } 163 241 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs
r6238 r6618 77 77 #endregion 78 78 79 public IValueParameter<StringValue> TargetVariableParameter {80 get { return ( IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }79 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 80 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 81 81 } 82 82 public string TargetVariable { -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs
r6415 r6618 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Data;26 using HeuristicLab.Optimization;27 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 26 … … 32 30 /// </summary> 33 31 [StorableClass] 34 public class RegressionSolution : DataAnalysisSolution, IRegressionSolution { 35 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 36 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 37 private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)"; 38 private const string TestSquaredCorrelationResultName = "Pearson's R² (test)"; 39 private const string TrainingRelativeErrorResultName = "Average relative error (training)"; 40 private const string TestRelativeErrorResultName = "Average relative error (test)"; 41 private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)"; 42 private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)"; 43 44 public new IRegressionModel Model { 45 get { return (IRegressionModel)base.Model; } 46 protected set { base.Model = value; } 47 } 48 49 public new IRegressionProblemData ProblemData { 50 get { return (IRegressionProblemData)base.ProblemData; } 51 protected set { base.ProblemData = value; } 52 } 53 54 public double TrainingMeanSquaredError { 55 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 56 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 57 } 58 59 public double TestMeanSquaredError { 60 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 61 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 62 } 63 64 public double TrainingRSquared { 65 get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; } 66 private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; } 67 } 68 69 public double TestRSquared { 70 get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; } 71 private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; } 72 } 73 74 public double TrainingRelativeError { 75 get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; } 76 private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; } 77 } 78 79 public double TestRelativeError { 80 get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; } 81 private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; } 82 } 83 84 public double TrainingNormalizedMeanSquaredError { 85 get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; } 86 private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; } 87 } 88 89 public double TestNormalizedMeanSquaredError { 90 get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; } 91 private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; } 92 } 93 32 public abstract class RegressionSolution : RegressionSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 94 34 95 35 [StorableConstructor] 96 protected RegressionSolution(bool deserializing) : base(deserializing) { } 36 protected RegressionSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 97 40 protected RegressionSolution(RegressionSolution original, Cloner cloner) 98 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 99 43 } 100 p ublicRegressionSolution(IRegressionModel model, IRegressionProblemData problemData)44 protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData) 101 45 : base(model, problemData) { 102 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 103 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 104 Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 105 Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 106 Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue())); 107 Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue())); 108 Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue())); 109 Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue())); 110 111 CalculateResults(); 112 } 113 114 public override IDeepCloneable Clone(Cloner cloner) { 115 return new RegressionSolution(this, cloner); 46 evaluationCache = new Dictionary<int, double>(); 116 47 } 117 48 … … 120 51 } 121 52 122 private void CalculateResults() { 123 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 124 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 125 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 126 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 127 128 OnlineCalculatorError errorState; 129 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 130 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 131 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 132 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 133 134 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 135 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 136 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 137 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 138 139 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 140 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN; 141 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 142 TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN; 143 144 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 145 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN; 146 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 147 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN; 53 public override IEnumerable<double> EstimatedValues { 54 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 55 } 56 public override IEnumerable<double> EstimatedTrainingValues { 57 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 58 } 59 public override IEnumerable<double> EstimatedTestValues { 60 get { return GetEstimatedValues(ProblemData.TestIndizes); } 148 61 } 149 62 150 public virtual IEnumerable<double> EstimatedValues { 151 get { 152 return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 63 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 64 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 65 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 66 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 67 68 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 69 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 153 70 } 71 72 return rows.Select(row => evaluationCache[row]); 154 73 } 155 74 156 public virtual IEnumerable<double> EstimatedTrainingValues { 157 get { 158 return GetEstimatedValues(ProblemData.TrainingIndizes); 159 } 75 protected override void OnProblemDataChanged() { 76 evaluationCache.Clear(); 77 base.OnProblemDataChanged(); 160 78 } 161 79 162 public virtual IEnumerable<double> EstimatedTestValues { 163 get { 164 return GetEstimatedValues(ProblemData.TestIndizes); 165 } 166 } 167 168 public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 169 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 80 protected override void OnModelChanged() { 81 evaluationCache.Clear(); 82 base.OnModelChanged(); 170 83 } 171 84 }
Note: See TracChangeset
for help on using the changeset viewer.