- Timestamp:
- 03/14/16 17:16:12 (9 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs
r12509 r13697 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; … … 91 92 } 92 93 94 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 95 var estimatedValuesEnumerators = GetEstimatedValueVectors(dataset, rows).GetEnumerator(); 96 var rowsEnumerator = rows.GetEnumerator(); 97 98 // aggregate to make sure that MoveNext is called for all enumerators 99 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.MoveNext()) { 100 int currentRow = rowsEnumerator.Current; 101 102 var filteredEstimates = models.Zip(estimatedValuesEnumerators.Current, 103 (m, e) => new { Model = m, EstimatedValue = e }).Where(f => modelSelectionPredicate(currentRow, f.Model)); 104 105 yield return filteredEstimates.Select(f => f.EstimatedValue).DefaultIfEmpty(double.NaN).Average(); 106 } 107 } 108 93 109 #endregion 94 110 95 111 #region IRegressionModel Members 96 97 112 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 98 113 foreach (var estimatedValuesVector in GetEstimatedValueVectors(dataset, rows)) { … … 107 122 return CreateRegressionSolution(problemData); 108 123 } 109 110 124 #endregion 111 125 } -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r12816 r13697 169 169 var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys); 170 170 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 171 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();171 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 172 172 173 173 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 184 184 var rowsToEvaluate = rows.Except(testEvaluationCache.Keys); 185 185 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 186 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();186 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 187 187 188 188 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 193 193 } 194 194 } 195 196 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {197 var estimatedValuesEnumerators = (from model in Model.Models198 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })199 .ToList();200 var rowsEnumerator = rows.GetEnumerator();201 // aggregate to make sure that MoveNext is called for all enumerators202 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {203 int currentRow = rowsEnumerator.Current;204 205 var selectedEnumerators = from pair in estimatedValuesEnumerators206 where modelSelectionPredicate(currentRow, pair.Model)207 select pair.EstimatedValuesEnumerator;208 209 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));210 }211 }212 213 195 private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) { 214 196 return trainingPartitions == null || !trainingPartitions.ContainsKey(model) || 215 197 (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End); 216 198 } 217 218 199 private bool RowIsTestForModel(int currentRow, IRegressionModel model) { 219 200 return testPartitions == null || !testPartitions.ContainsKey(model) || … … 224 205 var rowsToEvaluate = rows.Except(evaluationCache.Keys); 225 206 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 226 var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate) 227 select AggregateEstimatedValues(xs)) 228 .GetEnumerator(); 207 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 229 208 230 209 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { … … 235 214 } 236 215 237 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) { 238 if (!Model.Models.Any()) yield break; 239 var estimatedValuesEnumerators = (from model in Model.Models 240 select model.GetEstimatedValues(dataset, rows).GetEnumerator()) 241 .ToList(); 242 243 while (estimatedValuesEnumerators.All(en => en.MoveNext())) { 244 yield return from enumerator in estimatedValuesEnumerators 245 select enumerator.Current; 246 } 247 } 248 249 private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) { 250 return estimatedValues.DefaultIfEmpty(double.NaN).Average(); 216 public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) { 217 return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows); 251 218 } 252 219 #endregion
Note: See TracChangeset
for help on using the changeset viewer.