Changeset 6618 for branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
- Timestamp: 08/01/11 17:48:53 (13 years ago)
- Location: branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
- Files: 4 edited, 1 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r5809 r6618 34 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 35 36 [Storable]37 36 private List<IRegressionModel> models; 38 37 public IEnumerable<IRegressionModel> Models { 39 38 get { return new List<IRegressionModel>(models); } 40 39 } 40 41 [Storable(Name = "Models")] 42 private IEnumerable<IRegressionModel> StorableModels { 43 get { return models; } 44 set { models = value.ToList(); } 45 } 46 47 #region backwards compatiblity 3.3.5 48 [Storable(Name = "models", AllowOneWay = true)] 49 private List<IRegressionModel> OldStorableModels { 50 set { models = value; } 51 } 52 #endregion 53 41 54 [StorableConstructor] 42 55 protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { } … … 57 70 58 71 #region IRegressionEnsembleModel Members 72 73 public void Add(IRegressionModel model) { 74 models.Add(model); 75 } 76 public void Remove(IRegressionModel model) { 77 models.Remove(model); 78 } 59 79 60 80 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 79 99 } 80 100 101 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 102 return new RegressionEnsembleSolution(this.Models, problemData); 103 } 104 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 105 return CreateRegressionSolution(problemData); 106 } 107 81 108 #endregion 82 109 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r6377 r6618 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using System;28 using HeuristicLab.Data;29 30 30 31 namespace HeuristicLab.Problems.DataAnalysis { … … 35 36 [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")] 36 37 // [Creatable("Data Analysis")] 37 public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {38 public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution { 38 39 public new IRegressionEnsembleModel Model { 39 40 get { return (IRegressionEnsembleModel)base.Model; } 41 } 42 43 private readonly ItemCollection<IRegressionSolution> regressionSolutions; 44 public IItemCollection<IRegressionSolution> RegressionSolutions { 45 get { return regressionSolutions; } 40 46 } 41 47 … … 46 52 47 53 [StorableConstructor] 48 protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { } 49 protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 54 private RegressionEnsembleSolution(bool deserializing) 55 : base(deserializing) { 56 regressionSolutions = new ItemCollection<IRegressionSolution>(); 57 } 58 [StorableHook(HookType.AfterDeserialization)] 59 private void AfterDeserialization() { 60 foreach (var model in Model.Models) { 61 IRegressionProblemData problemData = (IRegressionProblemData)ProblemData.Clone(); 62 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 63 problemData.TrainingPartition.End = trainingPartitions[model].End; 64 problemData.TestPartition.Start = testPartitions[model].Start; 65 problemData.TestPartition.End = testPartitions[model].End; 66 67 
regressionSolutions.Add(model.CreateRegressionSolution(problemData)); 68 } 69 RegisterRegressionSolutionsEventHandler(); 70 } 71 72 private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 50 73 : base(original, cloner) { 51 74 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); … … 57 80 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 58 81 } 59 RecalculateResults(); 82 83 regressionSolutions = cloner.Clone(original.regressionSolutions); 84 RegisterRegressionSolutionsEventHandler(); 60 85 } 61 86 62 87 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 63 : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) { 64 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 65 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 66 foreach (var model in models) { 67 trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone(); 68 testPartitions[model] = (IntRange)problemData.TestPartition.Clone(); 69 } 70 RecalculateResults(); 71 } 88 : this(models, problemData, 89 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 90 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 91 ) { } 72 92 73 93 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 74 : base(new RegressionEnsembleModel( models), new RegressionEnsembleProblemData(problemData)) {94 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 75 95 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 76 96 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 97 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 98 99 List<IRegressionSolution> 
solutions = new List<IRegressionSolution>(); 77 100 var modelEnumerator = models.GetEnumerator(); 78 101 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 79 102 var testPartitionEnumerator = testPartitions.GetEnumerator(); 103 80 104 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 81 this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone(); 82 this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone(); 105 var p = (IRegressionProblemData)problemData.Clone(); 106 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 107 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 108 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 109 p.TestPartition.End = testPartitionEnumerator.Current.End; 110 111 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 83 112 } 84 113 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 85 114 throw new ArgumentException(); 86 115 } 87 RecalculateResults(); 116 117 RegisterRegressionSolutionsEventHandler(); 118 regressionSolutions.AddRange(solutions); 88 119 } 89 120 … … 91 122 return new RegressionEnsembleSolution(this, cloner); 92 123 } 93 124 private void RegisterRegressionSolutionsEventHandler() { 125 regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded); 126 regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved); 127 regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset); 128 } 129 130 protected override void RecalculateResults() { 131 CalculateResults(); 132 } 133 134 #region Evaluation 94 135 public override 
IEnumerable<double> EstimatedTrainingValues { 95 136 get { … … 160 201 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 161 202 } 203 #endregion 204 205 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 206 solutions.OfType<RegressionEnsembleSolution>().SelectMany(ensemble => ensemble.RegressionSolutions); 207 regressionSolutions.AddRange(solutions); 208 } 209 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 210 regressionSolutions.RemoveRange(solutions); 211 } 212 213 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 214 foreach (var solution in e.Items) AddRegressionSolution(solution); 215 RecalculateResults(); 216 } 217 private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 218 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 219 RecalculateResults(); 220 } 221 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 222 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 223 foreach (var solution in e.Items) AddRegressionSolution(solution); 224 RecalculateResults(); 225 } 226 227 private void AddRegressionSolution(IRegressionSolution solution) { 228 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 229 Model.Add(solution.Model); 230 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 231 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 232 } 233 234 private void RemoveRegressionSolution(IRegressionSolution solution) { 235 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 236 Model.Remove(solution.Model); 237 trainingPartitions.Remove(solution.Model); 238 testPartitions.Remove(solution.Model); 239 } 162 240 } 163 241 } -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs
r6238 r6618 77 77 #endregion 78 78 79 public IValueParameter<StringValue> TargetVariableParameter {80 get { return ( IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }79 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 80 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 81 81 } 82 82 public string TargetVariable { -
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs
r6415 r6618 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Data;26 using HeuristicLab.Optimization;27 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 26 … … 32 30 /// </summary> 33 31 [StorableClass] 34 public class RegressionSolution : DataAnalysisSolution, IRegressionSolution { 35 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 36 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 37 private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)"; 38 private const string TestSquaredCorrelationResultName = "Pearson's R² (test)"; 39 private const string TrainingRelativeErrorResultName = "Average relative error (training)"; 40 private const string TestRelativeErrorResultName = "Average relative error (test)"; 41 private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)"; 42 private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)"; 43 44 public new IRegressionModel Model { 45 get { return (IRegressionModel)base.Model; } 46 protected set { base.Model = value; } 47 } 48 49 public new IRegressionProblemData ProblemData { 50 get { return (IRegressionProblemData)base.ProblemData; } 51 protected set { base.ProblemData = value; } 52 } 53 54 public double TrainingMeanSquaredError { 55 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 56 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 57 } 58 59 public double TestMeanSquaredError { 60 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 61 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 62 } 63 64 public double TrainingRSquared { 65 get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; } 66 private 
set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; } 67 } 68 69 public double TestRSquared { 70 get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; } 71 private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; } 72 } 73 74 public double TrainingRelativeError { 75 get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; } 76 private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; } 77 } 78 79 public double TestRelativeError { 80 get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; } 81 private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; } 82 } 83 84 public double TrainingNormalizedMeanSquaredError { 85 get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; } 86 private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; } 87 } 88 89 public double TestNormalizedMeanSquaredError { 90 get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; } 91 private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; } 92 } 93 32 public abstract class RegressionSolution : RegressionSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 94 34 95 35 [StorableConstructor] 96 protected RegressionSolution(bool deserializing) : base(deserializing) { } 36 protected RegressionSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 97 40 protected RegressionSolution(RegressionSolution original, Cloner cloner) 98 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 99 43 } 100 p ublicRegressionSolution(IRegressionModel model, IRegressionProblemData problemData)44 protected RegressionSolution(IRegressionModel model, 
IRegressionProblemData problemData) 101 45 : base(model, problemData) { 102 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 103 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 104 Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 105 Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 106 Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue())); 107 Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue())); 108 Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue())); 109 Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue())); 110 111 CalculateResults(); 112 } 113 114 public override IDeepCloneable Clone(Cloner cloner) { 115 return new RegressionSolution(this, cloner); 46 evaluationCache = new Dictionary<int, double>(); 116 47 } 117 48 … … 120 51 } 121 52 122 private void CalculateResults() { 123 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 124 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 125 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 126 
IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 127 128 OnlineCalculatorError errorState; 129 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 130 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 131 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 132 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 133 134 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 135 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 136 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 137 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 138 139 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 140 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN; 141 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 142 TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN; 143 144 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 145 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? 
trainingNMSE : double.NaN; 146 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 147 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN; 53 public override IEnumerable<double> EstimatedValues { 54 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 55 } 56 public override IEnumerable<double> EstimatedTrainingValues { 57 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 58 } 59 public override IEnumerable<double> EstimatedTestValues { 60 get { return GetEstimatedValues(ProblemData.TestIndizes); } 148 61 } 149 62 150 public virtual IEnumerable<double> EstimatedValues { 151 get { 152 return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 63 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 64 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 65 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 66 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 67 68 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 69 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 153 70 } 71 72 return rows.Select(row => evaluationCache[row]); 154 73 } 155 74 156 public virtual IEnumerable<double> EstimatedTrainingValues { 157 get { 158 return GetEstimatedValues(ProblemData.TrainingIndizes); 159 } 75 protected override void OnProblemDataChanged() { 76 evaluationCache.Clear(); 77 base.OnProblemDataChanged(); 160 78 } 161 79 162 public virtual IEnumerable<double> EstimatedTestValues { 163 get { 164 return GetEstimatedValues(ProblemData.TestIndizes); 165 } 166 } 167 168 public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 169 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 80 protected override void OnModelChanged() { 81 evaluationCache.Clear(); 82 
base.OnModelChanged(); 170 83 } 171 84 }
Note: See TracChangeset for help on using the changeset viewer.