Changeset 14029 for branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
- Timestamp:
- 07/08/16 14:40:02 (8 years ago)
- Location:
- branches/crossvalidation-2434
- Files:
-
- 7 edited
- 2 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/ConstantRegressionModel.cs
r12509 r14029 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 24 25 using HeuristicLab.Common; 25 26 using HeuristicLab.Core; 27 using HeuristicLab.Data; 26 28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 29 … … 29 31 [StorableClass] 30 32 [Item("Constant Regression Model", "A model that always returns the same constant value regardless of the presented input data.")] 31 public class ConstantRegressionModel : NamedItem, IRegressionModel { 33 [Obsolete] 34 public class ConstantRegressionModel : RegressionModel, IStringConvertibleValue { 35 public override IEnumerable<string> VariablesUsedForPrediction { get { return Enumerable.Empty<string>(); } } 36 32 37 [Storable] 33 pr otecteddouble constant;38 private double constant; 34 39 public double Constant { 35 40 get { return constant; } 41 // setter not implemented because manipulation of the constant is not allowed 36 42 } 37 43 … … 42 48 this.constant = original.constant; 43 49 } 50 44 51 public override IDeepCloneable Clone(Cloner cloner) { return new ConstantRegressionModel(this, cloner); } 45 52 46 public ConstantRegressionModel(double constant )47 : base( ) {53 public ConstantRegressionModel(double constant, string targetVariable) 54 : base(targetVariable) { 48 55 this.name = ItemName; 49 56 this.description = ItemDescription; 50 57 this.constant = constant; 58 this.ReadOnly = true; // changing a constant regression model is not supported 51 59 } 52 60 53 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {61 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 54 62 return rows.Select(row => Constant); 55 63 } 56 64 57 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {58 return new ConstantRegressionSolution( this, new RegressionProblemData(problemData));65 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 66 return new ConstantRegressionSolution(new ConstantModel(constant, TargetVariable), new RegressionProblemData(problemData)); 59 67 } 68 69 public override string ToString() { 70 return string.Format("Constant: {0}", GetValue()); 71 } 72 73 #region IStringConvertibleValue 74 public bool ReadOnly { get; private set; } 75 public bool Validate(string value, out string errorMessage) { 76 throw new NotSupportedException(); // changing a constant regression model is not supported 77 } 78 79 public string GetValue() { 80 return string.Format("{0:E4}", constant); 81 } 82 83 public bool SetValue(string value) { 84 throw new NotSupportedException(); // changing a constant regression model is not supported 85 } 86 87 public event EventHandler ValueChanged; 88 #endregion 60 89 } 61 90 } -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/ConstantRegressionSolution.cs
r12012 r14029 28 28 [Item(Name = "Constant Regression Solution", Description = "Represents a constant regression solution (model + data).")] 29 29 public class ConstantRegressionSolution : RegressionSolution { 30 public new Constant RegressionModel Model {31 get { return (Constant RegressionModel)base.Model; }30 public new ConstantModel Model { 31 get { return (ConstantModel)base.Model; } 32 32 set { base.Model = value; } 33 33 } … … 36 36 protected ConstantRegressionSolution(bool deserializing) : base(deserializing) { } 37 37 protected ConstantRegressionSolution(ConstantRegressionSolution original, Cloner cloner) : base(original, cloner) { } 38 public ConstantRegressionSolution(Constant RegressionModel model, IRegressionProblemData problemData)38 public ConstantRegressionSolution(ConstantModel model, IRegressionProblemData problemData) 39 39 : base(model, problemData) { 40 40 RecalculateResults(); -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r12509 r14029 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 32 33 [StorableClass] 33 34 [Item("RegressionEnsembleModel", "A regression model that contains an ensemble of multiple regression models")] 34 public class RegressionEnsembleModel : NamedItem, IRegressionEnsembleModel { 35 public sealed class RegressionEnsembleModel : RegressionModel, IRegressionEnsembleModel { 36 public override IEnumerable<string> VariablesUsedForPrediction { 37 get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 38 } 35 39 36 40 private List<IRegressionModel> models; … … 45 49 } 46 50 51 private List<double> modelWeights; 52 public IEnumerable<double> ModelWeights { 53 get { return modelWeights; } 54 } 55 56 [Storable(Name = "ModelWeights")] 57 private IEnumerable<double> StorableModelWeights { 58 get { return modelWeights; } 59 set { modelWeights = value.ToList(); } 60 } 61 62 [Storable] 63 private bool averageModelEstimates = true; 64 public bool AverageModelEstimates { 65 get { return averageModelEstimates; } 66 set { 67 if (averageModelEstimates != value) { 68 averageModelEstimates = value; 69 OnChanged(); 70 } 71 } 72 } 73 47 74 #region backwards compatiblity 3.3.5 48 75 [Storable(Name = "models", AllowOneWay = true)] … … 52 79 #endregion 53 80 81 [StorableHook(HookType.AfterDeserialization)] 82 private void AfterDeserialization() { 83 // BackwardsCompatibility 3.3.14 84 #region Backwards compatible code, remove with 3.4 85 if (modelWeights == null || !modelWeights.Any()) 86 modelWeights = new List<double>(models.Select(m => 1.0)); 87 #endregion 88 } 89 54 90 [StorableConstructor] 55 pr otectedRegressionEnsembleModel(bool deserializing) : base(deserializing) { }56 pr otectedRegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner)91 private RegressionEnsembleModel(bool deserializing) : base(deserializing) { } 92 private RegressionEnsembleModel(RegressionEnsembleModel original, Cloner cloner) 57 93 : base(original, cloner) { 58 this.models = original.Models.Select(m => cloner.Clone(m)).ToList(); 94 this.models = original.Models.Select(cloner.Clone).ToList(); 95 this.modelWeights = new List<double>(original.ModelWeights); 96 this.averageModelEstimates = original.averageModelEstimates; 97 } 98 public override IDeepCloneable Clone(Cloner cloner) { 99 return new RegressionEnsembleModel(this, cloner); 59 100 } 60 101 61 102 public RegressionEnsembleModel() : this(Enumerable.Empty<IRegressionModel>()) { } 62 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) 63 : base() { 103 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models) : this(models, models.Select(m => 1.0)) { } 104 public RegressionEnsembleModel(IEnumerable<IRegressionModel> models, IEnumerable<double> modelWeights) 105 : base(string.Empty) { 64 106 this.name = ItemName; 65 107 this.description = ItemDescription; 108 66 109 this.models = new List<IRegressionModel>(models); 67 } 68 69 public override IDeepCloneable Clone(Cloner cloner) { 70 return new RegressionEnsembleModel(this, cloner); 71 } 72 73 #region IRegressionEnsembleModel Members 110 this.modelWeights = new List<double>(modelWeights); 111 112 if (this.models.Any()) this.TargetVariable = this.models.First().TargetVariable; 113 } 74 114 75 115 public void Add(IRegressionModel model) { 116 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable; 117 Add(model, 1.0); 118 } 119 public void Add(IRegressionModel model, double weight) { 120 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = model.TargetVariable; 121 76 122 models.Add(model); 77 } 123 modelWeights.Add(weight); 124 OnChanged(); 125 } 126 127 public void AddRange(IEnumerable<IRegressionModel> models) { 128 AddRange(models, models.Select(m => 1.0)); 129 } 130 public void AddRange(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) { 131 if (string.IsNullOrEmpty(TargetVariable)) TargetVariable = models.First().TargetVariable; 132 133 this.models.AddRange(models); 134 modelWeights.AddRange(weights); 135 OnChanged(); 136 } 137 78 138 public void Remove(IRegressionModel model) { 79 models.Remove(model); 80 } 81 139 var index = models.IndexOf(model); 140 models.RemoveAt(index); 141 modelWeights.RemoveAt(index); 142 143 if (!models.Any()) TargetVariable = string.Empty; 144 OnChanged(); 145 } 146 public void RemoveRange(IEnumerable<IRegressionModel> models) { 147 foreach (var model in models) { 148 var index = this.models.IndexOf(model); 149 this.models.RemoveAt(index); 150 modelWeights.RemoveAt(index); 151 } 152 153 if (!models.Any()) TargetVariable = string.Empty; 154 OnChanged(); 155 } 156 157 public double GetModelWeight(IRegressionModel model) { 158 var index = models.IndexOf(model); 159 return modelWeights[index]; 160 } 161 public void SetModelWeight(IRegressionModel model, double weight) { 162 var index = models.IndexOf(model); 163 modelWeights[index] = weight; 164 OnChanged(); 165 } 166 167 #region evaluation 82 168 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 83 169 var estimatedValuesEnumerators = (from model in models 84 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 85 .ToList(); 170 let weight = GetModelWeight(model) 171 select model.GetEstimatedValues(dataset, rows).Select(e => weight * e) 172 .GetEnumerator()).ToList(); 86 173 87 174 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { … … 91 178 } 92 179 180 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 181 double weightsSum = modelWeights.Sum(); 182 var summedEstimates = from estimatedValuesVector in GetEstimatedValueVectors(dataset, rows) 183 select estimatedValuesVector.DefaultIfEmpty(double.NaN).Sum(); 184 185 if (AverageModelEstimates) 186 return summedEstimates.Select(v => v / weightsSum); 187 else 188 return summedEstimates; 189 190 } 191 192 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 193 var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator(); 194 var rowsEnumerator = rows.GetEnumerator(); 195 196 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) { 197 var estimatedValueEnumerator = estimatedValuesEnumerators.Current.GetEnumerator(); 198 int currentRow = rowsEnumerator.Current; 199 double weightsSum = 0.0; 200 double filteredEstimatesSum = 0.0; 201 202 for (int m = 0; m < models.Count; m++) { 203 estimatedValueEnumerator.MoveNext(); 204 var model = models[m]; 205 if (!modelSelectionPredicate(currentRow, model)) continue; 206 207 filteredEstimatesSum += estimatedValueEnumerator.Current; 208 weightsSum += modelWeights[m]; 209 } 210 211 if (AverageModelEstimates) 212 yield return filteredEstimatesSum / weightsSum; 213 else 214 yield return filteredEstimatesSum; 215 } 216 } 217 93 218 #endregion 94 219 95 #region IRegressionModel Members 96 97 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 98 foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) { 99 yield return estimatedValuesVector.Average(); 100 } 101 } 102 103 public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) { 104 return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData)); 105 } 106 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 107 return CreateRegressionSolution(problemData); 108 } 109 110 #endregion 220 public event EventHandler Changed; 221 private void OnChanged() { 222 var handler = Changed; 223 if (handler != null) 224 handler(this, EventArgs.Empty); 225 } 226 227 228 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 229 return new RegressionEnsembleSolution(this, new RegressionEnsembleProblemData(problemData)); 230 } 111 231 } 112 232 } -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r12820 r14029 79 79 } 80 80 } 81 82 RegisterModelEvents(); 81 83 RegisterRegressionSolutionsEventHandler(); 82 84 } … … 93 95 } 94 96 97 evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows); 95 98 trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count()); 96 99 testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count()); 97 100 98 101 regressionSolutions = cloner.Clone(original.regressionSolutions); 102 RegisterModelEvents(); 99 103 RegisterRegressionSolutionsEventHandler(); 100 104 } … … 106 110 regressionSolutions = new ItemCollection<IRegressionSolution>(); 107 111 112 RegisterModelEvents(); 108 113 RegisterRegressionSolutionsEventHandler(); 109 114 } 110 115 111 116 public RegressionEnsembleSolution(IRegressionProblemData problemData) 112 : this(Enumerable.Empty<IRegressionModel>(), problemData) { 113 } 114 115 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 116 : this(models, problemData, 117 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 118 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 119 ) { } 120 121 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 122 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 123 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 124 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 125 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 126 127 List<IRegressionSolution> solutions = new List<IRegressionSolution>(); 128 var modelEnumerator = models.GetEnumerator(); 129 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 130 var testPartitionEnumerator = testPartitions.GetEnumerator(); 131 132 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 133 var p = (IRegressionProblemData)problemData.Clone(); 134 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 135 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 136 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 137 p.TestPartition.End = testPartitionEnumerator.Current.End; 138 139 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 140 } 141 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 142 throw new ArgumentException(); 143 } 144 117 : this(new RegressionEnsembleModel(), problemData) { 118 } 119 120 public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData) 121 : base(model, new RegressionEnsembleProblemData(problemData)) { 122 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 123 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 124 regressionSolutions = new ItemCollection<IRegressionSolution>(); 125 126 evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows); 145 127 trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count()); 146 128 testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count()); 147 129 130 131 var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone())); 132 foreach (var solution in solutions) { 133 regressionSolutions.Add(solution); 134 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 135 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 136 } 137 138 RecalculateResults(); 139 RegisterModelEvents(); 148 140 RegisterRegressionSolutionsEventHandler(); 149 regressionSolutions.AddRange(solutions);150 } 141 } 142 151 143 152 144 public override IDeepCloneable Clone(Cloner cloner) { 153 145 return new RegressionEnsembleSolution(this, cloner); 146 } 147 148 private void RegisterModelEvents() { 149 Model.Changed += Model_Changed; 154 150 } 155 151 private void RegisterRegressionSolutionsEventHandler() { … … 168 164 var rows = ProblemData.TrainingIndices; 169 165 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 166 170 167 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 171 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();168 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 172 169 173 170 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 184 181 var rowsToEvaluate = rows.Except(testEvaluationCache.Keys); 185 182 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 186 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();183 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 187 184 188 185 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 193 190 } 194 191 } 195 196 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {197 var estimatedValuesEnumerators = (from model in Model.Models198 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })199 .ToList();200 var rowsEnumerator = rows.GetEnumerator();201 // aggregate to make sure that MoveNext is called for all enumerators202 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {203 int currentRow = rowsEnumerator.Current;204 205 var selectedEnumerators = from pair in estimatedValuesEnumerators206 where modelSelectionPredicate(currentRow, pair.Model)207 select pair.EstimatedValuesEnumerator;208 209 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));210 }211 }212 213 192 private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) { 214 193 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 215 194 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 216 195 } 217 218 196 private bool RowIsTestForModel(int currentRow, IRegressionModel model) { 219 197 return testPartitions == null || !testPartitions.ContainsKey(model) || … … 224 202 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 225 203 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 226 var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate) 227 select AggregateEstimatedValues(xs)) 228 .GetEnumerator(); 204 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 229 205 230 206 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 235 211 } 236 212 237 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 238 if (!Model.Models.Any()) yield break; 239 var estimatedValuesEnumerators = (from model in Model.Models 240 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 241 .ToList(); 242 243 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 244 yield return from enumerator in estimatedValuesEnumerators 245 select enumerator.Current; 246 } 247 } 248 249 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 250 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 213 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) { 214 return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows); 251 215 } 252 216 #endregion … … 282 246 } 283 247 284 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 285 regressionSolutions.AddRange(solutions); 248 private void Model_Changed(object sender, EventArgs e) { 249 var modelSet = new HashSet<IRegressionModel>(Model.Models); 250 foreach (var model in Model.Models) { 251 if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition); 252 if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition); 253 } 254 foreach (var model in trainingPartitions.Keys) { 255 if (modelSet.Contains(model)) continue; 256 trainingPartitions.Remove(model); 257 testPartitions.Remove(model); 258 } 286 259 287 260 trainingEvaluationCache.Clear(); 288 261 testEvaluationCache.Clear(); 289 262 evaluationCache.Clear(); 263 264 OnModelChanged(); 265 } 266 267 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 268 regressionSolutions.AddRange(solutions); 290 269 } 291 270 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 292 271 regressionSolutions.RemoveRange(solutions); 293 294 trainingEvaluationCache.Clear();295 testEvaluationCache.Clear();296 evaluationCache.Clear();297 272 } 298 273 299 274 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 300 foreach (var solution in e.Items) AddRegressionSolution(solution); 301 RecalculateResults(); 275 foreach (var solution in e.Items) { 276 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 277 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 278 } 279 Model.AddRange(e.Items.Select(s => s.Model)); 302 280 } 303 281 private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 304 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 305 RecalculateResults(); 282 foreach (var solution in e.Items) { 283 trainingPartitions.Remove(solution.Model); 284 testPartitions.Remove(solution.Model); 285 } 286 Model.RemoveRange(e.Items.Select(s => s.Model)); 306 287 } 307 288 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 308 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 309 foreach (var solution in e.Items) AddRegressionSolution(solution); 310 RecalculateResults(); 311 } 312 313 private void AddRegressionSolution(IRegressionSolution solution) { 314 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 315 Model.Add(solution.Model); 316 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 317 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 318 319 trainingEvaluationCache.Clear(); 320 testEvaluationCache.Clear(); 321 evaluationCache.Clear(); 322 } 323 324 private void RemoveRegressionSolution(IRegressionSolution solution) { 325 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 326 Model.Remove(solution.Model); 327 trainingPartitions.Remove(solution.Model); 328 testPartitions.Remove(solution.Model); 329 330 trainingEvaluationCache.Clear(); 331 testEvaluationCache.Clear(); 332 evaluationCache.Clear(); 289 foreach (var solution in e.OldItems) { 290 trainingPartitions.Remove(solution.Model); 291 testPartitions.Remove(solution.Model); 292 } 293 Model.RemoveRange(e.OldItems.Select(s => s.Model)); 294 295 foreach (var solution in e.Items) { 296 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 297 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 298 } 299 Model.AddRange(e.Items.Select(s => s.Model)); 333 300 } 334 301 } -
branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs
r12509 r14029 110 110 } 111 111 112 public IEnumerable<double> TargetVariableValues { 113 get { return Dataset.GetDoubleValues(TargetVariable); } 114 } 115 public IEnumerable<double> TargetVariableTrainingValues { 116 get { return Dataset.GetDoubleValues(TargetVariable, TrainingIndices); } 117 } 118 public IEnumerable<double> TargetVariableTestValues { 119 get { return Dataset.GetDoubleValues(TargetVariable, TestIndices); } 120 } 121 122 112 123 [StorableConstructor] 113 124 protected RegressionProblemData(bool deserializing) : base(deserializing) { }
Note: See TracChangeset
for help on using the changeset viewer.