Changeset 6760 for branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs
- Timestamp:
- 09/14/11 13:59:25 (13 years ago)
- Location:
- branches/PersistenceSpeedUp
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/PersistenceSpeedUp
- Property svn:ignore
-
old new 12 12 *.psess 13 13 *.vsp 14 *.docstates
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs
r6184 r6760 34 34 35 35 namespace HeuristicLab.Algorithms.DataAnalysis { 36 [Item("Cross Validation", "Cross Validation wrapper for data analysis algorithms.")]36 [Item("Cross Validation", "Cross-validation wrapper for data analysis algorithms.")] 37 37 [Creatable("Data Analysis")] 38 38 [StorableClass] … … 363 363 364 364 public void CollectResultValues(IDictionary<string, IItem> results) { 365 var clonedResults = (ResultCollection)this.results.Clone(); 366 foreach (var result in clonedResults) { 367 results.Add(result.Name, result.Value); 368 } 369 } 370 371 private void AggregateResultValues(IDictionary<string, IItem> results) { 365 372 Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>(); 366 373 IEnumerable<IRun> runs = clonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null); … … 374 381 results.Add(result.Name, result.Value); 375 382 foreach (IResult result in ExtractAndAggregateRegressionSolutions(resultCollections)) { 383 results.Add(result.Name, result.Value); 384 } 385 foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections)) { 376 386 results.Add(result.Name, result.Value); 377 387 } … … 394 404 List<IResult> aggregatedResults = new List<IResult>(); 395 405 foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in resultSolutions) { 396 var problemDataClone = (IRegressionProblemData)Problem.ProblemData.Clone(); 406 // clone manually to correctly clone references between cloned root objects 407 Cloner cloner = new Cloner(); 408 var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData); 409 // set partitions of problem data clone correctly 397 410 problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value; 398 411 problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value; 399 var ensembleSolution = new RegressionEnsembleSolution(solutions.Value.Select(x => x.Model), problemDataClone, 400 solutions.Value.Select(x => x.ProblemData.TrainingPartition), 401 solutions.Value.Select(x => x.ProblemData.TestPartition)); 402 403 aggregatedResults.Add(new Result(solutions.Key, ensembleSolution)); 404 } 405 return aggregatedResults; 412 // clone models 413 var ensembleSolution = new RegressionEnsembleSolution( 414 solutions.Value.Select(x => cloner.Clone(x.Model)), 415 problemDataClone, 416 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)), 417 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition))); 418 419 aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution)); 420 } 421 List<IResult> flattenedResults = new List<IResult>(); 422 CollectResultsRecursively("", aggregatedResults, flattenedResults); 423 return flattenedResults; 424 } 425 426 private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) { 427 Dictionary<string, List<IClassificationSolution>> resultSolutions = new Dictionary<string, List<IClassificationSolution>>(); 428 foreach (var result in resultCollections) { 429 var classificationSolution = result.Value as IClassificationSolution; 430 if (classificationSolution != null) { 431 if (resultSolutions.ContainsKey(result.Key)) { 432 resultSolutions[result.Key].Add(classificationSolution); 433 } else { 434 resultSolutions.Add(result.Key, new List<IClassificationSolution>() { classificationSolution }); 435 } 436 } 437 } 438 var aggregatedResults = new List<IResult>(); 439 foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) { 440 // clone manually to correctly clone references between cloned root objects 441 Cloner cloner = new Cloner(); 442 var problemDataClone = (IClassificationProblemData)cloner.Clone(Problem.ProblemData); 443 // set partitions of problem data clone correctly 444 problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value; 445 problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value; 446 // clone models 447 var ensembleSolution = new ClassificationEnsembleSolution( 448 solutions.Value.Select(x => cloner.Clone(x.Model)), 449 problemDataClone, 450 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)), 451 solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition))); 452 453 aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution)); 454 } 455 List<IResult> flattenedResults = new List<IResult>(); 456 CollectResultsRecursively("", aggregatedResults, flattenedResults); 457 return flattenedResults; 458 } 459 460 private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) { 461 foreach (IResult result in results) { 462 flattenedResults.Add(new Result(path + result.Name, result.Value)); 463 ResultCollection childCollection = result.Value as ResultCollection; 464 if (childCollection != null) { 465 CollectResultsRecursively(path + result.Name + ".", childCollection, flattenedResults); 466 } 467 } 406 468 } 407 469 … … 428 490 foreach (KeyValuePair<string, List<double>> resultValue in resultValues) { 429 491 doubleValue.Value = resultValue.Value.Average(); 430 aggregatedResults.Add(new Result(resultValue.Key , (IItem)doubleValue.Clone()));492 aggregatedResults.Add(new Result(resultValue.Key + " (average)", (IItem)doubleValue.Clone())); 431 493 doubleValue.Value = resultValue.Value.StandardDeviation(); 432 aggregatedResults.Add(new Result(resultValue.Key + " StdDev", (IItem)doubleValue.Clone()));494 aggregatedResults.Add(new Result(resultValue.Key + " (std.dev.)", (IItem)doubleValue.Clone())); 433 495 } 434 496 return aggregatedResults; … … 481 543 throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems."); 482 544 } 545 algorithm.Problem.Reset += (x,y) => OnProblemChanged(); 483 546 problem = (IDataAnalysisProblem)algorithm.Problem; 484 547 OnProblemChanged(); … … 491 554 SamplesStart.Value = 0; 492 555 if (Problem != null) { 493 Problem.ProblemDataChanged += (object sender, EventArgs e) => OnProblemChanged();494 556 SamplesEnd.Value = Problem.ProblemData.Dataset.Rows; 495 557 … … 510 572 } else 511 573 SamplesEnd.Value = 0; 574 575 SamplesStart_ValueChanged(this, EventArgs.Empty); 576 SamplesEnd_ValueChanged(this, EventArgs.Empty); 512 577 } 513 578 … … 656 721 stopPending = false; 657 722 Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>(); 658 CollectResultValues(collectedResults);723 AggregateResultValues(collectedResults); 659 724 results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray()); 660 725 runsCounter++;
Note: See TracChangeset
for help on using the changeset viewer.