Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/11 13:59:25 (13 years ago)
Author:
epitzer
Message:

#1530 integrate changes from trunk

Location:
branches/PersistenceSpeedUp
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/PersistenceSpeedUp

  • branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs

    r6184 r6760  
    3434
    3535namespace HeuristicLab.Algorithms.DataAnalysis {
    36   [Item("Cross Validation", "Cross Validation wrapper for data analysis algorithms.")]
     36  [Item("Cross Validation", "Cross-validation wrapper for data analysis algorithms.")]
    3737  [Creatable("Data Analysis")]
    3838  [StorableClass]
     
    363363
    364364    public void CollectResultValues(IDictionary<string, IItem> results) {
     365      var clonedResults = (ResultCollection)this.results.Clone();
     366      foreach (var result in clonedResults) {
     367        results.Add(result.Name, result.Value);
     368      }
     369    }
     370
     371    private void AggregateResultValues(IDictionary<string, IItem> results) {
    365372      Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();
    366373      IEnumerable<IRun> runs = clonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null);
     
    374381        results.Add(result.Name, result.Value);
    375382      foreach (IResult result in ExtractAndAggregateRegressionSolutions(resultCollections)) {
     383        results.Add(result.Name, result.Value);
     384      }
     385      foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections)) {
    376386        results.Add(result.Name, result.Value);
    377387      }
     
    394404      List<IResult> aggregatedResults = new List<IResult>();
    395405      foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in resultSolutions) {
    396         var problemDataClone = (IRegressionProblemData)Problem.ProblemData.Clone();
     406        // clone manually to correctly clone references between cloned root objects
     407        Cloner cloner = new Cloner();
     408        var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);
     409        // set partitions of problem data clone correctly
    397410        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
    398411        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
    399         var ensembleSolution = new RegressionEnsembleSolution(solutions.Value.Select(x => x.Model), problemDataClone,
    400           solutions.Value.Select(x => x.ProblemData.TrainingPartition),
    401           solutions.Value.Select(x => x.ProblemData.TestPartition));
    402 
    403         aggregatedResults.Add(new Result(solutions.Key, ensembleSolution));
    404       }
    405       return aggregatedResults;
     412        // clone models
     413        var ensembleSolution = new RegressionEnsembleSolution(
     414          solutions.Value.Select(x => cloner.Clone(x.Model)),
     415          problemDataClone,
     416          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),
     417          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));
     418
     419        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
     420      }
     421      List<IResult> flattenedResults = new List<IResult>();
     422      CollectResultsRecursively("", aggregatedResults, flattenedResults);
     423      return flattenedResults;
     424    }
     425
     426    private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
     427      Dictionary<string, List<IClassificationSolution>> resultSolutions = new Dictionary<string, List<IClassificationSolution>>();
     428      foreach (var result in resultCollections) {
     429        var classificationSolution = result.Value as IClassificationSolution;
     430        if (classificationSolution != null) {
     431          if (resultSolutions.ContainsKey(result.Key)) {
     432            resultSolutions[result.Key].Add(classificationSolution);
     433          } else {
     434            resultSolutions.Add(result.Key, new List<IClassificationSolution>() { classificationSolution });
     435          }
     436        }
     437      }
     438      var aggregatedResults = new List<IResult>();
     439      foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {
     440        // clone manually to correctly clone references between cloned root objects
     441        Cloner cloner = new Cloner();
     442        var problemDataClone = (IClassificationProblemData)cloner.Clone(Problem.ProblemData);
     443        // set partitions of problem data clone correctly
     444        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
     445        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
     446        // clone models
     447        var ensembleSolution = new ClassificationEnsembleSolution(
     448          solutions.Value.Select(x => cloner.Clone(x.Model)),
     449          problemDataClone,
     450          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),
     451          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));
     452
     453        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
     454      }
     455      List<IResult> flattenedResults = new List<IResult>();
     456      CollectResultsRecursively("", aggregatedResults, flattenedResults);
     457      return flattenedResults;
     458    }
     459
     460    private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) {
     461      foreach (IResult result in results) {
     462        flattenedResults.Add(new Result(path + result.Name, result.Value));
     463        ResultCollection childCollection = result.Value as ResultCollection;
     464        if (childCollection != null) {
     465          CollectResultsRecursively(path + result.Name + ".", childCollection, flattenedResults);
     466        }
     467      }
    406468    }
    407469
     
    428490      foreach (KeyValuePair<string, List<double>> resultValue in resultValues) {
    429491        doubleValue.Value = resultValue.Value.Average();
    430         aggregatedResults.Add(new Result(resultValue.Key, (IItem)doubleValue.Clone()));
     492        aggregatedResults.Add(new Result(resultValue.Key + " (average)", (IItem)doubleValue.Clone()));
    431493        doubleValue.Value = resultValue.Value.StandardDeviation();
    432         aggregatedResults.Add(new Result(resultValue.Key + " StdDev", (IItem)doubleValue.Clone()));
     494        aggregatedResults.Add(new Result(resultValue.Key + " (std.dev.)", (IItem)doubleValue.Clone()));
    433495      }
    434496      return aggregatedResults;
     
    481543        throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
    482544      }
     545      algorithm.Problem.Reset += (x,y) => OnProblemChanged();
    483546      problem = (IDataAnalysisProblem)algorithm.Problem;
    484547      OnProblemChanged();
     
    491554      SamplesStart.Value = 0;
    492555      if (Problem != null) {
    493         Problem.ProblemDataChanged += (object sender, EventArgs e) => OnProblemChanged();
    494556        SamplesEnd.Value = Problem.ProblemData.Dataset.Rows;
    495557
     
    510572      } else
    511573        SamplesEnd.Value = 0;
     574
     575      SamplesStart_ValueChanged(this, EventArgs.Empty);
     576      SamplesEnd_ValueChanged(this, EventArgs.Empty);
    512577    }
    513578
     
    656721      stopPending = false;
    657722      Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();
    658       CollectResultValues(collectedResults);
     723      AggregateResultValues(collectedResults);
    659724      results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());
    660725      runsCounter++;
Note: See TracChangeset for help on using the changeset viewer.