Changeset 13976 for stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
- Timestamp: 07/02/16 09:02:09 (7 years ago)
- Location: stable
- Files: 3 edited
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
/trunk/sources merged: 13697-13698,13700-13702,13704-13705,13711,13715
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Problems.DataAnalysis merged: 13697-13698,13700-13702,13704-13705,13715
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r13049 r13976 79 79 } 80 80 } 81 82 RegisterModelEvents(); 81 83 RegisterRegressionSolutionsEventHandler(); 82 84 } … … 93 95 } 94 96 97 evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows); 95 98 trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count()); 96 99 testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count()); 97 100 98 101 regressionSolutions = cloner.Clone(original.regressionSolutions); 102 RegisterModelEvents(); 99 103 RegisterRegressionSolutionsEventHandler(); 100 104 } … … 106 110 regressionSolutions = new ItemCollection<IRegressionSolution>(); 107 111 112 RegisterModelEvents(); 108 113 RegisterRegressionSolutionsEventHandler(); 109 114 } 110 115 111 116 public RegressionEnsembleSolution(IRegressionProblemData problemData) 112 : this(Enumerable.Empty<IRegressionModel>(), problemData) { 113 } 114 115 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData) 116 : this(models, problemData, 117 models.Select(m => (IntRange)problemData.TrainingPartition.Clone()), 118 models.Select(m => (IntRange)problemData.TestPartition.Clone()) 119 ) { } 120 121 public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions) 122 : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) { 123 this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 124 this.testPartitions = new Dictionary<IRegressionModel, IntRange>(); 125 this.regressionSolutions = new ItemCollection<IRegressionSolution>(); 126 127 List<IRegressionSolution> solutions = new List<IRegressionSolution>(); 128 var modelEnumerator = models.GetEnumerator(); 129 var trainingPartitionEnumerator = trainingPartitions.GetEnumerator(); 130 var 
testPartitionEnumerator = testPartitions.GetEnumerator(); 131 132 while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) { 133 var p = (IRegressionProblemData)problemData.Clone(); 134 p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start; 135 p.TrainingPartition.End = trainingPartitionEnumerator.Current.End; 136 p.TestPartition.Start = testPartitionEnumerator.Current.Start; 137 p.TestPartition.End = testPartitionEnumerator.Current.End; 138 139 solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p)); 140 } 141 if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) { 142 throw new ArgumentException(); 143 } 144 117 : this(new RegressionEnsembleModel(), problemData) { 118 } 119 120 public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData) 121 : base(model, new RegressionEnsembleProblemData(problemData)) { 122 trainingPartitions = new Dictionary<IRegressionModel, IntRange>(); 123 testPartitions = new Dictionary<IRegressionModel, IntRange>(); 124 regressionSolutions = new ItemCollection<IRegressionSolution>(); 125 126 evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows); 145 127 trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count()); 146 128 testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count()); 147 129 130 131 var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone())); 132 foreach (var solution in solutions) { 133 regressionSolutions.Add(solution); 134 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 135 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 136 } 137 138 RecalculateResults(); 139 RegisterModelEvents(); 148 140 RegisterRegressionSolutionsEventHandler(); 149 
regressionSolutions.AddRange(solutions);150 } 141 } 142 151 143 152 144 public override IDeepCloneable Clone(Cloner cloner) { 153 145 return new RegressionEnsembleSolution(this, cloner); 146 } 147 148 private void RegisterModelEvents() { 149 Model.Changed += Model_Changed; 154 150 } 155 151 private void RegisterRegressionSolutionsEventHandler() { … … 168 164 var rows = ProblemData.TrainingIndices; 169 165 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 166 170 167 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 171 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();168 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 172 169 173 170 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 184 181 var rowsToEvaluate = rows.Except(testEvaluationCache.Keys); 185 182 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 186 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();183 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 187 184 188 185 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 193 190 } 194 191 } 195 196 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {197 var estimatedValuesEnumerators = (from model in Model.Models198 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })199 .ToList();200 var rowsEnumerator = rows.GetEnumerator();201 // aggregate to make sure that MoveNext is called for all enumerators202 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, 
(acc, b) => acc & b)) {203 int currentRow = rowsEnumerator.Current;204 205 var selectedEnumerators = from pair in estimatedValuesEnumerators206 where modelSelectionPredicate(currentRow, pair.Model)207 select pair.EstimatedValuesEnumerator;208 209 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));210 }211 }212 213 192 private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) { 214 193 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 215 194 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 216 195 } 217 218 196 private bool RowIsTestForModel(int currentRow, IRegressionModel model) { 219 197 return testPartitions == null || !testPartitions.ContainsKey(model) || … … 224 202 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 225 203 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 226 var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate) 227 select AggregateEstimatedValues(xs)) 228 .GetEnumerator(); 204 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 229 205 230 206 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 235 211 } 236 212 237 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 238 if (!Model.Models.Any()) yield break; 239 var estimatedValuesEnumerators = (from model in Model.Models 240 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 241 .ToList(); 242 243 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 244 yield return from enumerator in estimatedValuesEnumerators 245 select enumerator.Current; 246 } 247 } 248 249 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 250 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 213 public IEnumerable<IEnumerable<double>> 
GetEstimatedValueVectors(IEnumerable<int> rows) { 214 return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows); 251 215 } 252 216 #endregion … … 282 246 } 283 247 284 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 285 regressionSolutions.AddRange(solutions); 248 private void Model_Changed(object sender, EventArgs e) { 249 var modelSet = new HashSet<IRegressionModel>(Model.Models); 250 foreach (var model in Model.Models) { 251 if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition); 252 if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition); 253 } 254 foreach (var model in trainingPartitions.Keys) { 255 if (modelSet.Contains(model)) continue; 256 trainingPartitions.Remove(model); 257 testPartitions.Remove(model); 258 } 286 259 287 260 trainingEvaluationCache.Clear(); 288 261 testEvaluationCache.Clear(); 289 262 evaluationCache.Clear(); 263 264 OnModelChanged(); 265 } 266 267 public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 268 regressionSolutions.AddRange(solutions); 290 269 } 291 270 public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) { 292 271 regressionSolutions.RemoveRange(solutions); 293 294 trainingEvaluationCache.Clear();295 testEvaluationCache.Clear();296 evaluationCache.Clear();297 272 } 298 273 299 274 private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 300 foreach (var solution in e.Items) AddRegressionSolution(solution); 301 RecalculateResults(); 275 foreach (var solution in e.Items) { 276 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 277 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 278 } 279 Model.AddRange(e.Items.Select(s => s.Model)); 302 280 } 303 281 private void regressionSolutions_ItemsRemoved(object sender, 
CollectionItemsChangedEventArgs<IRegressionSolution> e) { 304 foreach (var solution in e.Items) RemoveRegressionSolution(solution); 305 RecalculateResults(); 282 foreach (var solution in e.Items) { 283 trainingPartitions.Remove(solution.Model); 284 testPartitions.Remove(solution.Model); 285 } 286 Model.RemoveRange(e.Items.Select(s => s.Model)); 306 287 } 307 288 private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) { 308 foreach (var solution in e.OldItems) RemoveRegressionSolution(solution); 309 foreach (var solution in e.Items) AddRegressionSolution(solution); 310 RecalculateResults(); 311 } 312 313 private void AddRegressionSolution(IRegressionSolution solution) { 314 if (Model.Models.Contains(solution.Model)) throw new ArgumentException(); 315 Model.Add(solution.Model); 316 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 317 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 318 319 trainingEvaluationCache.Clear(); 320 testEvaluationCache.Clear(); 321 evaluationCache.Clear(); 322 } 323 324 private void RemoveRegressionSolution(IRegressionSolution solution) { 325 if (!Model.Models.Contains(solution.Model)) throw new ArgumentException(); 326 Model.Remove(solution.Model); 327 trainingPartitions.Remove(solution.Model); 328 testPartitions.Remove(solution.Model); 329 330 trainingEvaluationCache.Clear(); 331 testEvaluationCache.Clear(); 332 evaluationCache.Clear(); 289 foreach (var solution in e.OldItems) { 290 trainingPartitions.Remove(solution.Model); 291 testPartitions.Remove(solution.Model); 292 } 293 Model.RemoveRange(e.OldItems.Select(s => s.Model)); 294 295 foreach (var solution in e.Items) { 296 trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition); 297 testPartitions.Add(solution.Model, solution.ProblemData.TestPartition); 298 } 299 Model.AddRange(e.Items.Select(s => s.Model)); 333 300 } 334 301 }
Note: See TracChangeset for help on using the changeset viewer.