Changeset 8151
- Timestamp:
- 06/28/12 15:55:16 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs
r8139 r8151 37 37 [Creatable("Data Analysis - Ensembles")] 38 38 public sealed class RegressionEnsembleSolution : RegressionSolution, IRegressionEnsembleSolution { 39 private readonly Dictionary<int, double> trainingEstimatedValuesCache = new Dictionary<int, double>(); 40 private readonly Dictionary<int, double> testEstimatedValuesCache = new Dictionary<int, double>(); 41 private readonly Dictionary<int, double> estimatedValuesCache = new Dictionary<int, double>(); 42 39 43 public new IRegressionEnsembleModel Model { 40 44 get { return (IRegressionEnsembleModel)base.Model; } … … 152 156 #region Evaluation 153 157 public override IEnumerable<double> EstimatedTrainingValues { 154 get { 155 var rows = ProblemData.TrainingIndices; 156 var estimatedValuesEnumerators = (from model in Model.Models 157 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 158 .ToList(); 159 var rowsEnumerator = rows.GetEnumerator(); 160 // aggregate to make sure that MoveNext is called for all enumerators 161 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 162 int currentRow = rowsEnumerator.Current; 163 164 var selectedEnumerators = from pair in estimatedValuesEnumerators 165 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 166 select pair.EstimatedValuesEnumerator; 167 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); 168 } 169 } 158 get { return GetEstimatedValues(ProblemData.TrainingIndices, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)); } 170 159 } 171 160 172 161 public override IEnumerable<double> EstimatedTestValues { 173 get { 174 var rows = ProblemData.TestIndices; 175 var estimatedValuesEnumerators = (from model in Model.Models 176 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 177 .ToList(); 178 var rowsEnumerator = ProblemData.TestIndices.GetEnumerator(); 179 // aggregate to make sure that MoveNext is called for all enumerators 180 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 181 int currentRow = rowsEnumerator.Current; 182 183 var selectedEnumerators = from pair in estimatedValuesEnumerators 184 where RowIsTestForModel(currentRow, pair.Model) 185 select pair.EstimatedValuesEnumerator; 186 187 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); 188 } 162 get { return GetEstimatedValues(ProblemData.TestIndices, RowIsTestForModel); } 163 } 164 165 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) { 166 var estimatedValuesEnumerators = (from model in Model.Models 167 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() }) 168 .ToList(); 169 var rowsEnumerator = rows.GetEnumerator(); 170 // aggregate to make sure that MoveNext is called for all enumerators 171 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 172 int currentRow = rowsEnumerator.Current; 173 174 var selectedEnumerators = from pair in estimatedValuesEnumerators 175 where modelSelectionPredicate(currentRow, pair.Model) 176 select pair.EstimatedValuesEnumerator; 177 178 yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current)); 189 179 } 190 180 }
Note: See TracChangeset
for help on using the changeset viewer.