Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/11/11 15:56:17 (13 years ago)
Author:
gkronber
Message:

#1450: merged r5816 from the branch and implemented first version of ensemble solutions for regression. The ensembles are only produced by cross validation.

Location:
trunk/sources
Files:
3 edited
1 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources

  • trunk/sources/HeuristicLab.Problems.DataAnalysis

  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6180 r6184  
    7575        throw new ArgumentException();
    7676      }
     77
     78      RecalculateResults();
     79    }
     80
     81    private void RecalculateResults() {
     82      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
     83      var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,
     84        ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
     85      IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);
     86      double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
     87      IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
     88
     89      OnlineCalculatorError errorState;
     90      double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
     91      TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
     92      double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
     93      TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
     94
     95      double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
     96      TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
     97      double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
     98      TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
     99
     100      double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
     101      TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
     102      double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
     103      TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
     104
     105      double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
     106      TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
     107      double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
     108      TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
    77109    }
    78110
     
    83115    public override IEnumerable<double> EstimatedTrainingValues {
    84116      get {
     117        var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
    85118        var estimatedValuesEnumerators = (from model in Model.Models
    86                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })
     119                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    87120                                         .ToList();
    88         var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
    89         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.All(en => en.EstimatedValuesEnumerator.MoveNext())) {
     121        var rowsEnumerator = rows.GetEnumerator();
     122        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    90123          int currentRow = rowsEnumerator.Current;
    91124
    92125          var selectedEnumerators = from pair in estimatedValuesEnumerators
    93126                                    where trainingPartitions == null || !trainingPartitions.ContainsKey(pair.Model) ||
    94                                          (trainingPartitions[pair.Model].Start >= currentRow && trainingPartitions[pair.Model].End < currentRow)
     127                                         (trainingPartitions[pair.Model].Start <= currentRow && currentRow < trainingPartitions[pair.Model].End)
    95128                                    select pair.EstimatedValuesEnumerator;
    96129          yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
     
    105138                                         .ToList();
    106139        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
    107         while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.All(en => en.EstimatedValuesEnumerator.MoveNext())) {
     140        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    108141          int currentRow = rowsEnumerator.Current;
    109142
    110143          var selectedEnumerators = from pair in estimatedValuesEnumerators
    111144                                    where testPartitions == null || !testPartitions.ContainsKey(pair.Model) ||
    112                                       (testPartitions[pair.Model].Start >= currentRow && testPartitions[pair.Model].End < currentRow)
     145                                      (testPartitions[pair.Model].Start <= currentRow && currentRow < testPartitions[pair.Model].End)
    113146                                    select pair.EstimatedValuesEnumerator;
    114147
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r5962 r6184  
    3030namespace HeuristicLab.Problems.DataAnalysis {
    3131  /// <summary>
    32   /// Abstract base class for regression data analysis solutions
     32  /// Represents a regression data analysis solution
    3333  /// </summary>
    3434  [StorableClass]
    35   public abstract class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
     35  public class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    3636    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    3737    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
     
    5555    public double TrainingMeanSquaredError {
    5656      get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    57       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     57      protected set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    5858    }
    5959
    6060    public double TestMeanSquaredError {
    6161      get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    62       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     62      protected set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    6363    }
    6464
    6565    public double TrainingRSquared {
    6666      get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    67       private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
     67      protected set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
    6868    }
    6969
    7070    public double TestRSquared {
    7171      get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    72       private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
     72      protected set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    7373    }
    7474
    7575    public double TrainingRelativeError {
    7676      get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    77       private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
     77      protected set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    7878    }
    7979
    8080    public double TestRelativeError {
    8181      get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    82       private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
     82      protected set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    8383    }
    8484
    8585    public double TrainingNormalizedMeanSquaredError {
    8686      get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    87       private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
     87      protected set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    8888    }
    8989
    9090    public double TestNormalizedMeanSquaredError {
    9191      get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    92       private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
     92      protected set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    9393    }
    9494
     
    113113    }
    114114
     115    public override IDeepCloneable Clone(Cloner cloner) {
     116      return new RegressionSolution(this, cloner);
     117    }
     118
    115119    protected override void OnProblemDataChanged(EventArgs e) {
    116120      base.OnProblemDataChanged(e);
     
    122126    }
    123127
    124     protected void RecalculateResults() {
     128    private void RecalculateResults() {
    125129      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    126130      IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
Note: See TracChangeset for help on using the changeset viewer.