Changeset 8153
- Timestamp: 06/28/12 16:25:58 (12 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r8139 r8153 37 37 [Creatable("Data Analysis - Ensembles")] 38 38 public sealed class ClassificationEnsembleSolution : ClassificationSolution, IClassificationEnsembleSolution { 39 private readonly Dictionary<int, double> trainingEstimatedValuesCache = new Dictionary<int, double>(); 40 private readonly Dictionary<int, double> testEstimatedValuesCache = new Dictionary<int, double>(); 41 private readonly Dictionary<int, double> estimatedValuesCache = new Dictionary<int, double>(); 42 39 43 public new IClassificationEnsembleModel Model { 40 44 get { return (IClassificationEnsembleModel)base.Model; } … … 149 153 get { 150 154 var rows = ProblemData.TrainingIndices; 151 var estimatedValuesEnumerators = (from model in Model.Models 152 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 153 .ToList(); 154 var rowsEnumerator = rows.GetEnumerator(); 155 // aggregate to make sure that MoveNext is called for all enumerators 156 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 157 int currentRow = rowsEnumerator.Current; 158 159 var selectedEnumerators = from pair in estimatedValuesEnumerators 160 where RowIsTrainingForModel(currentRow, pair.Model) && !RowIsTestForModel(currentRow, pair.Model) 161 select pair.EstimatedValuesEnumerator; 162 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 155 var rowsToEvaluate = rows.Except(trainingEstimatedValuesCache.Keys); 156 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 157 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator(); 158 159 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 160 trainingEstimatedValuesCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 163 161 } 162 163 return rows.Select(row 
=> trainingEstimatedValuesCache[row]); 164 164 } 165 165 } … … 168 168 get { 169 169 var rows = ProblemData.TestIndices; 170 var estimatedValuesEnumerators = (from model in Model.Models 171 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 172 .ToList(); 173 var rowsEnumerator = ProblemData.TestIndices.GetEnumerator(); 174 // aggregate to make sure that MoveNext is called for all enumerators 175 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 176 int currentRow = rowsEnumerator.Current; 177 178 var selectedEnumerators = from pair in estimatedValuesEnumerators 179 where RowIsTestForModel(currentRow, pair.Model) 180 select pair.EstimatedValuesEnumerator; 181 182 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 170 var rowsToEvaluate = rows.Except(testEstimatedValuesCache.Keys); 171 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 172 var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator(); 173 174 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 175 testEstimatedValuesCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 183 176 } 177 178 return rows.Select(row => testEstimatedValuesCache[row]); 179 } 180 } 181 182 private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IClassificationModel, bool> modelSelectionPredicate) { 183 var estimatedValuesEnumerators = (from model in Model.Models 184 select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedClassValues(ProblemData.Dataset, rows).GetEnumerator() }) 185 .ToList(); 186 var rowsEnumerator = rows.GetEnumerator(); 187 // aggregate to make sure that MoveNext is called for all enumerators 188 while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => 
en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) { 189 int currentRow = rowsEnumerator.Current; 190 191 var selectedEnumerators = from pair in estimatedValuesEnumerators 192 where modelSelectionPredicate(currentRow, pair.Model) 193 select pair.EstimatedValuesEnumerator; 194 195 yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current)); 184 196 } 185 197 } … … 196 208 197 209 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 198 return from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rows) 199 select AggregateEstimatedClassValues(xs); 210 var rowsToEvaluate = rows.Except(estimatedValuesCache.Keys); 211 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 212 var valuesEnumerator = (from xs in GetEstimatedClassValueVectors(ProblemData.Dataset, rowsToEvaluate) 213 select AggregateEstimatedClassValues(xs)) 214 .GetEnumerator(); 215 216 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 217 estimatedValuesCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 218 } 219 220 return rows.Select(row => estimatedValuesCache[row]); 200 221 } 201 222 … … 223 244 224 245 protected override void OnProblemDataChanged() { 246 trainingEstimatedValuesCache.Clear(); 247 testEstimatedValuesCache.Clear(); 248 estimatedValuesCache.Clear(); 249 225 250 IClassificationProblemData problemData = new ClassificationProblemData(ProblemData.Dataset, 226 251 ProblemData.AllowedInputVariables, … … 251 276 public void AddClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 252 277 classificationSolutions.AddRange(solutions); 278 279 trainingEstimatedValuesCache.Clear(); 280 testEstimatedValuesCache.Clear(); 281 estimatedValuesCache.Clear(); 253 282 } 254 283 public void RemoveClassificationSolutions(IEnumerable<IClassificationSolution> solutions) { 255 284 classificationSolutions.RemoveRange(solutions); 285 286 trainingEstimatedValuesCache.Clear(); 
287 testEstimatedValuesCache.Clear(); 288 estimatedValuesCache.Clear(); 256 289 } 257 290 … … 275 308 trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition; 276 309 testPartitions[solution.Model] = solution.ProblemData.TestPartition; 310 311 trainingEstimatedValuesCache.Clear(); 312 testEstimatedValuesCache.Clear(); 313 estimatedValuesCache.Clear(); 277 314 } 278 315 … … 282 319 trainingPartitions.Remove(solution.Model); 283 320 testPartitions.Remove(solution.Model); 321 322 trainingEstimatedValuesCache.Clear(); 323 testEstimatedValuesCache.Clear(); 324 estimatedValuesCache.Clear(); 284 325 } 285 326 }
Note: See TracChangeset for help on using the changeset viewer.