Changeset 6760 for branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
- Timestamp:
- 09/14/11 13:59:25 (13 years ago)
- Location:
- branches/PersistenceSpeedUp
- Files:
-
- 6 edited
- 2 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/PersistenceSpeedUp
- Property svn:ignore
-
old new 12 12 *.psess 13 13 *.vsp 14 *.docstates
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r5809 r6760 34 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 35 36 [Storable]37 36 private List<IRegressionModel> models; 38 37 public IEnumerable<IRegressionModel> Models { 39 38 get { return new List<IRegressionModel>(models); } 40 39 } 40 41 [Storable(Name = "Models")] 42 private IEnumerable<IRegressionModel> StorableModels { 43 get { return models; } 44 set { models = value.ToList(); } 45 } 46 47 #region backwards compatiblity 3.3.5 48 [Storable(Name = "models", AllowOneWay = true)] 49 private List<IRegressionModel> OldStorableModels { 50 set { models = value; } 51 } 52 #endregion 53 41 54 [StorableConstructor] 42 55 protected RegressionEnsembleModel(bool deserializing) : base(deserializing) { } … … 45 58 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 46 59 } 60 61 public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { } 47 62 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) 48 63 : base() { … … 57 72 58 73 #region IRegressionEnsembleModel Members 74 75 public void Add(IRegressionModel model) { 76 models.Add(model); 77 } 78 public void Remove(IRegressionModel model) { 79 models.Remove(model); 80 } 59 81 60 82 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(Dataset dataset, IEnumerable<int> rows) { … … 79 101 } 80 102 103 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 104 return new RegressionEnsembleSolution(this.Models, problemData); 105 } 106 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 107 return CreateRegressionSolution(problemData); 108 } 109 81 110 #endregion 82 111 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using HeuristicLab.Collections; 24 26 using HeuristicLab.Common; 25 27 using HeuristicLab.Core; 28 using HeuristicLab.Data; 26 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using System;28 using HeuristicLab.Data;29 30 30 31 namespace HeuristicLab.Problems.DataAnalysis { … … 34 35 [StorableClass] 35 36 [Item("Regression Ensemble Solution", "A regression solution that contains an ensemble of multiple regression models")] 36 // [Creatable("Data Analysis")]37 public class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution {37 [Creatable("Data Analysis - Ensembles")] 38 public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution { 38 39 public new IRegressionEnsembleModel Model { 39 40 get { return (IRegressionEnsembleModel)base.Model; } 41 } 42 43 public new RegressionEnsembleProblemData ProblemData { 44 get { return (RegressionEnsembleProblemData)base.ProblemData; } 45 set { base.ProblemData = value; } 46 } 47 48 private readonly ItemCollection<IRegressionSolution> regressionSolutions; 49 public IItemCollection<IRegressionSolution> RegressionSolutions { 50 get { return regressionSolutions; } 40 51 } 41 52 … … 46 57 47 58 [StorableConstructor] 48 protected RegressionEnsembleSolution(bool deserializing) : base(deserializing) { } 49 protected RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 59 private RegressionEnsembleSolution(bool deserializing) 60 : base(deserializing) { 61 regressionSolutions = new ItemCollection<IRegressionSolution>(); 62 } 63 [StorableHook(HookType.AfterDeserialization)] 64 private void AfterDeserialization() { 65 foreach (var model in Model.Models) { 66 IRegressionProblemData problemData = (IRegressionProblemData) ProblemData.Clone(); 67 problemData.TrainingPartition.Start = trainingPartitions[model].Start; 
68 problemData.TrainingPartition.End = trainingPartitions[model].End; 69 problemData.TestPartition.Start = testPartitions[model].Start; 70 problemData.TestPartition.End = testPartitions[model].End; 71 72 regressionSolutions.Add(model.CreateRegressionSolution(problemData)); 73 } 74 RegisterRegressionSolutionsEventHandler(); 75 } 76 77 private RegressionEnsembleSolution(RegressionEnsembleSolution original, Cloner cloner) 50 78 : base(original, cloner) { 51 }52 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)53 : base(new RegressionEnsembleModel(models), problemData) {54 79 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 55 80 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 56 foreach (var model in models) { 57 trainingPartitions[model] = (IntRange)problemData.TrainingPartition.Clone(); 58 testPartitions[model] = (IntRange)problemData.TestPartition.Clone(); 59 } 60 RecalculateResults(); 61 } 81 foreach (var pair in original.trainingPartitions) { 82 trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 83 } 84 foreach (var pair in original.testPartitions) { 85 testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value); 86 } 87 88 regressionSolutions = cloner.Clone(original.regressionSolutions); 89 RegisterRegressionSolutionsEventHandler(); 90 } 91 92 public RegressionEnsembleSolution() 93 : base(new RegressionEnsembleModel(), RegressionEnsembleProblemData.EmptyProblemData) { 94 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 95 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 96 regressionSolutions = new ItemCollection<IRegressionSolution>(); 97 98 RegisterRegressionSolutionsEventHandler(); 99 } 100 101 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 102 : this(models, problemData, 103 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 104 
models.Select(m => (IntRange)problemData.TestPartition.Clone()) 105 ) { } 62 106 63 107 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 64 : base(new RegressionEnsembleModel( models), problemData) {108 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 65 109 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 66 110 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 111 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 112 113 List<IRegressionSolution> solutions = new List<IRegressionSolution>(); 67 114 var modelEnumerator = models.GetEnumerator(); 68 115 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 69 116 var testPartitionEnumerator = testPartitions.GetEnumerator(); 117 70 118 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 71 this.trainingPartitions[modelEnumerator.Current] = (IntRange)trainingPartitionEnumerator.Current.Clone(); 72 this.testPartitions[modelEnumerator.Current] = (IntRange)testPartitionEnumerator.Current.Clone(); 119 var p = (IRegressionProblemData)problemData.Clone(); 120 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 121 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 122 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 123 p.TestPartition.End = testPartitionEnumerator.Current.End; 124 125 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 73 126 } 74 127 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { … … 76 129 } 77 130 78 RecalculateResults(); 79 } 80 81 private void RecalculateResults() { 82 double[] estimatedTrainingValues = 
EstimatedTrainingValues.ToArray(); // cache values 83 var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start, 84 ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start); 85 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes); 86 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 87 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 88 89 OnlineCalculatorError errorState; 90 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 91 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 92 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 93 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 94 95 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 96 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 97 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 98 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 99 100 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 101 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN; 102 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 103 TestRelativeError = errorState == OnlineCalculatorError.None ? 
testRelError : double.NaN; 104 105 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 106 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN; 107 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 108 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN; 131 RegisterRegressionSolutionsEventHandler(); 132 regressionSolutions.AddRange(solutions); 109 133 } 110 134 … … 112 136 return new RegressionEnsembleSolution(this, cloner); 113 137 } 114 138 private void RegisterRegressionSolutionsEventHandler() { 139 regressionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsAdded); 140 regressionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_ItemsRemoved); 141 regressionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IRegressionSolution>(regressionSolutions_CollectionReset); 142 } 143 144 protected override void RecalculateResults() { 145 CalculateResults(); 146 } 147 148 #region Evaluation 115 149 public override IEnumerable<double> EstimatedTrainingValues { 116 150 get { 117 var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);151 var rows = ProblemData.TrainingIndizes; 118 152 var estimatedValuesEnumerators = (from model in Model.Models 119 153 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 120 154 .ToList(); 121 155 var rowsEnumerator = rows.GetEnumerator(); 156 // aggregate to make sure that MoveNext is called for all enumerators 122 157 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => 
en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 123 158 int currentRow = rowsEnumerator.Current; 124 159 125 160 var selectedEnumerators = from pair in estimatedValuesEnumerators 126 where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) || 127 (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End) 161 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 128 162 select pair.EstimatedValuesEnumerator; 129 163 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); … … 134 168 public override IEnumerable<double> EstimatedTestValues { 135 169 get { 170 var rows = ProblemData.TestIndizes; 136 171 var estimatedValuesEnumerators = (from model in Model.Models 137 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })172 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 138 173 .ToList(); 139 174 var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator(); 175 // aggregate to make sure that MoveNext is called for all enumerators 140 176 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 141 177 int currentRow = rowsEnumerator.Current; 142 178 143 179 var selectedEnumerators = from pair in estimatedValuesEnumerators 144 where testPartitions == null || !testPartitions.ContainsKey(pair.Model) || 145 (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End) 180 where RowIsTestForModel(currentRow, pair.Model) 146 181 select pair.EstimatedValuesEnumerator; 147 182 … … 149 184 } 150 185 } 186 } 187 188 private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) { 189 return 
trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 190 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 191 } 192 193 private bool RowIsTestForModel(int currentRow, IRegressionModel model) { 194 return testPartitions == null || !testPartitions.ContainsKey(model) || 195 (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End); 151 196 } 152 197 … … 168 213 169 214 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 170 return estimatedValues.Average(); 171 } 172 173 //[Storable] 174 //private string name; 175 //public string Name { 176 // get { 177 // return name; 178 // } 179 // set { 180 // if (value != null && value != name) { 181 // var cancelEventArgs = new CancelEventArgs<string>(value); 182 // OnNameChanging(cancelEventArgs); 183 // if (cancelEventArgs.Cancel == false) { 184 // name = value; 185 // OnNamedChanged(EventArgs.Empty); 186 // } 187 // } 188 // } 189 //} 190 191 //public bool CanChangeName { 192 // get { return true; } 193 //} 194 195 //[Storable] 196 //private string description; 197 //public string Description { 198 // get { 199 // return description; 200 // } 201 // set { 202 // if (value != null && value != description) { 203 // description = value; 204 // OnDescriptionChanged(EventArgs.Empty); 205 // } 206 // } 207 //} 208 209 //public bool CanChangeDescription { 210 // get { return true; } 211 //} 212 213 //#region events 214 //public event EventHandler<CancelEventArgs<string>> NameChanging; 215 //private void OnNameChanging(CancelEventArgs<string> cancelEventArgs) { 216 // var listener = NameChanging; 217 // if (listener != null) listener(this, cancelEventArgs); 218 //} 219 220 //public event EventHandler NameChanged; 221 //private void OnNamedChanged(EventArgs e) { 222 // var listener = NameChanged; 223 // if (listener != null) listener(this, e); 224 //} 225 226 //public event EventHandler DescriptionChanged; 227 
//private void OnDescriptionChanged(EventArgs e) { 228 // var listener = DescriptionChanged; 229 // if (listener != null) listener(this, e); 230 //} 231 // #endregion 215 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 216 } 217 #endregion 218 219 protected override void OnProblemDataChanged() { 220 IRegressionProblemData problemData = new RegressionProblemData(ProblemData.Dataset, 221 ProblemData.AllowedInputVariables, 222 ProblemData.TargetVariable); 223 problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start; 224 problemData.TrainingPartition.End = ProblemData.TrainingPartition.End; 225 problemData.TestPartition.Start = ProblemData.TestPartition.Start; 226 problemData.TestPartition.End = ProblemData.TestPartition.End; 227 228 foreach (var solution in RegressionSolutions) { 229 if (solution is RegressionEnsembleSolution) 230 solution.ProblemData = ProblemData; 231 else 232 solution.ProblemData = problemData; 233 } 234 foreach (var trainingPartition in trainingPartitions.Values) { 235 trainingPartition.Start = ProblemData.TrainingPartition.Start; 236 trainingPartition.End = ProblemData.TrainingPartition.End; 237 } 238 foreach (var testPartition in testPartitions.Values) { 239 testPartition.Start = ProblemData.TestPartition.Start; 240 testPartition.End = ProblemData.TestPartition.End; 241 } 242 243 base.OnProblemDataChanged(); 244 } 245 246 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 247 regressionSolutions.AddRange(solutions); 248 } 249 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 250 regressionSolutions.RemoveRange(solutions); 251 } 252 253 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 254 foreach (var solution in e.Items) AddRegressionSolution(solution); 255 RecalculateResults(); 256 } 257 private void regressionSolutions_ItemsRemoved(object sender, 
CollectionItemsChangedEventArgs<IRegressionSolution> e) { 258 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 259 RecalculateResults(); 260 } 261 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 262 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 263 foreach (var solution in e.Items) AddRegressionSolution(solution); 264 RecalculateResults(); 265 } 266 267 private void AddRegressionSolution(IRegressionSolution solution) { 268 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 269 Model.Add(solution.Model); 270 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 271 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 272 } 273 274 private void RemoveRegressionSolution(IRegressionSolution solution) { 275 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 276 Model.Remove(solution.Model); 277 trainingPartitions.Remove(solution.Model); 278 testPartitions.Remove(solution.Model); 279 } 232 280 } 233 281 } -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs
r5809 r6760 33 33 [StorableClass] 34 34 [Item("RegressionProblemData", "Represents an item containing all data defining a regression problem.")] 35 public sealed class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData { 36 private const string TargetVariableParameterName = "TargetVariable"; 35 public class RegressionProblemData : DataAnalysisProblemData, IRegressionProblemData { 36 protected const string TargetVariableParameterName = "TargetVariable"; 37 37 38 38 #region default data … … 64 64 {0.83763905, 0.468046718} 65 65 }; 66 private static Dataset defaultDataset; 67 private static IEnumerable<string> defaultAllowedInputVariables; 68 private static string defaultTargetVariable; 66 private static readonly Dataset defaultDataset; 67 private static readonly IEnumerable<string> defaultAllowedInputVariables; 68 private static readonly string defaultTargetVariable; 69 70 private static readonly RegressionProblemData emptyProblemData; 71 public static RegressionProblemData EmptyProblemData { 72 get { return emptyProblemData; } 73 } 69 74 70 75 static RegressionProblemData() { … … 74 79 defaultAllowedInputVariables = new List<string>() { "x" }; 75 80 defaultTargetVariable = "y"; 81 82 var problemData = new RegressionProblemData(); 83 problemData.Parameters.Clear(); 84 problemData.Name = "Empty Regression ProblemData"; 85 problemData.Description = "This ProblemData acts as place holder before the correct problem data is loaded."; 86 problemData.isEmpty = true; 87 88 problemData.Parameters.Add(new FixedValueParameter<Dataset>(DatasetParameterName, "", new Dataset())); 89 problemData.Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemList<StringValue>>(InputVariablesParameterName, "")); 90 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TrainingPartitionParameterName, "", (IntRange)new IntRange(0, 0).AsReadOnly())); 91 problemData.Parameters.Add(new FixedValueParameter<IntRange>(TestPartitionParameterName, "", (IntRange)new 
IntRange(0, 0).AsReadOnly())); 92 problemData.Parameters.Add(new ConstrainedValueParameter<StringValue>(TargetVariableParameterName, new ItemSet<StringValue>())); 93 emptyProblemData = problemData; 76 94 } 77 95 #endregion 78 96 79 public IValueParameter<StringValue> TargetVariableParameter { 80 get { return (IValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 97 public ConstrainedValueParameter<StringValue> TargetVariableParameter { 98 get { return (ConstrainedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; } 81 99 } 82 100 public string TargetVariable { … … 85 103 86 104 [StorableConstructor] 87 private RegressionProblemData(bool deserializing) : base(deserializing) { } 105 protected RegressionProblemData(bool deserializing) : base(deserializing) { } 88 106 [StorableHook(HookType.AfterDeserialization)] 89 107 private void AfterDeserialization() { … … 91 109 } 92 110 93 94 private RegressionProblemData(RegressionProblemData original, Cloner cloner) 111 protected RegressionProblemData(RegressionProblemData original, Cloner cloner) 95 112 : base(original, cloner) { 96 113 RegisterParameterEvents(); 97 114 } 98 public override IDeepCloneable Clone(Cloner cloner) { return new RegressionProblemData(this, cloner); } 115 public override IDeepCloneable Clone(Cloner cloner) { 116 if (this == emptyProblemData) return emptyProblemData; 117 return new RegressionProblemData(this, cloner); 118 } 99 119 100 120 public RegressionProblemData() … … 124 144 dataset.Name = Path.GetFileName(fileName); 125 145 126 RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.VariableNames.Skip(1), dataset.VariableNames.First()); 146 RegressionProblemData problemData = new RegressionProblemData(dataset, dataset.DoubleVariables.Skip(1), dataset.DoubleVariables.First()); 127 147 problemData.Name = "Data imported from " + Path.GetFileName(fileName); 128 148 return problemData; -
branches/PersistenceSpeedUp/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs
r6184 r6760 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 using HeuristicLab.Data;27 using HeuristicLab.Optimization;28 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 26 … … 33 30 /// </summary> 34 31 [StorableClass] 35 public class RegressionSolution : DataAnalysisSolution, IRegressionSolution { 36 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 37 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 38 private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)"; 39 private const string TestSquaredCorrelationResultName = "Pearson's R² (test)"; 40 private const string TrainingRelativeErrorResultName = "Average relative error (training)"; 41 private const string TestRelativeErrorResultName = "Average relative error (test)"; 42 private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)"; 43 private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)"; 32 public abstract class RegressionSolution : RegressionSolutionBase { 33 protected readonly Dictionary<int, double> evaluationCache; 44 34 45 public new IRegressionModel Model { 46 get { return (IRegressionModel)base.Model; } 47 protected set { base.Model = value; } 35 [StorableConstructor] 36 protected RegressionSolution(bool deserializing) 37 : base(deserializing) { 38 evaluationCache = new Dictionary<int, double>(); 39 } 40 protected RegressionSolution(RegressionSolution original, Cloner cloner) 41 : base(original, cloner) { 42 evaluationCache = new Dictionary<int, double>(original.evaluationCache); 43 } 44 protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData) 45 : base(model, problemData) { 46 evaluationCache = new Dictionary<int, double>(); 48 47 } 49 48 50 public new 
IRegressionProblemData ProblemData { 51 get { return (IRegressionProblemData)base.ProblemData; } 52 protected set { base.ProblemData = value; } 49 protected override void RecalculateResults() { 50 CalculateResults(); 53 51 } 54 52 55 public double TrainingMeanSquaredError { 56 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 57 protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 53 public override IEnumerable<double> EstimatedValues { 54 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 55 } 56 public override IEnumerable<double> EstimatedTrainingValues { 57 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 58 } 59 public override IEnumerable<double> EstimatedTestValues { 60 get { return GetEstimatedValues(ProblemData.TestIndizes); } 58 61 } 59 62 60 public double TestMeanSquaredError { 61 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 62 protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 63 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 64 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 65 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 66 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 67 68 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 69 evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 70 } 71 72 return rows.Select(row => evaluationCache[row]); 63 73 } 64 74 65 p ublic double TrainingRSquared{66 get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }67 protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }75 protected override void OnProblemDataChanged() { 76 evaluationCache.Clear(); 77 base.OnProblemDataChanged(); 68 78 } 69 79 70 public double TestRSquared 
{ 71 get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; } 72 protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; } 73 } 74 75 public double TrainingRelativeError { 76 get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; } 77 protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; } 78 } 79 80 public double TestRelativeError { 81 get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; } 82 protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; } 83 } 84 85 public double TrainingNormalizedMeanSquaredError { 86 get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; } 87 protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; } 88 } 89 90 public double TestNormalizedMeanSquaredError { 91 get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; } 92 protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; } 93 } 94 95 96 [StorableConstructor] 97 protected RegressionSolution(bool deserializing) : base(deserializing) { } 98 protected RegressionSolution(RegressionSolution original, Cloner cloner) 99 : base(original, cloner) { 100 } 101 public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData) 102 : base(model, problemData) { 103 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 104 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 105 Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 106 Add(new 
Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 107 Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue())); 108 Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue())); 109 Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue())); 110 Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue())); 111 112 RecalculateResults(); 113 } 114 115 public override IDeepCloneable Clone(Cloner cloner) { 116 return new RegressionSolution(this, cloner); 117 } 118 119 protected override void OnProblemDataChanged(EventArgs e) { 120 base.OnProblemDataChanged(e); 121 RecalculateResults(); 122 } 123 protected override void OnModelChanged(EventArgs e) { 124 base.OnModelChanged(e); 125 RecalculateResults(); 126 } 127 128 private void RecalculateResults() { 129 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 130 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 131 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 132 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 133 134 OnlineCalculatorError errorState; 135 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 136 TrainingMeanSquaredError = errorState == 
OnlineCalculatorError.None ? trainingMSE : double.NaN; 137 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 138 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 139 140 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 141 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 142 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 143 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 144 145 double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 146 TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN; 147 double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 148 TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN; 149 150 double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 151 TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN; 152 double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 153 TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? 
testNMSE : double.NaN; 154 } 155 156 public virtual IEnumerable<double> EstimatedValues { 157 get { 158 return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 159 } 160 } 161 162 public virtual IEnumerable<double> EstimatedTrainingValues { 163 get { 164 return GetEstimatedValues(ProblemData.TrainingIndizes); 165 } 166 } 167 168 public virtual IEnumerable<double> EstimatedTestValues { 169 get { 170 return GetEstimatedValues(ProblemData.TestIndizes); 171 } 172 } 173 174 public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 175 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 80 protected override void OnModelChanged() { 81 evaluationCache.Clear(); 82 base.OnModelChanged(); 176 83 } 177 84 }
Note: See TracChangeset
for help on using the changeset viewer.