Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/20/11 15:07:45 (13 years ago)
Author:
gkronber
Message:

#1450 adapted views for regression solution to work for ensembles of regression solutions as well.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r6184 r6238  
    5151    }
    5252    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    53       : base(new RegressionEnsembleModel(models), problemData) {
     53      : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
    5454      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    5555      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     
    6262
    6363    public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    64       : base(new RegressionEnsembleModel(models), problemData) {
     64      : base(new RegressionEnsembleModel(models), new RegressionEnsembleProblemData(problemData)) {
    6565      this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    6666      this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
     
    7575        throw new ArgumentException();
    7676      }
    77 
    7877      RecalculateResults();
    79     }
    80 
    81     private void RecalculateResults() {
    82       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    83       var trainingIndizes = Enumerable.Range(ProblemData.TrainingPartition.Start,
    84         ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
    85       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, trainingIndizes);
    86       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    87       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    88 
    89       OnlineCalculatorError errorState;
    90       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    91       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    92       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    93       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    94 
    95       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    96       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    97       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    98       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    99 
    100       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    101       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    102       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    103       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    104 
    105       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    106       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    107       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    108       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
    10978    }
    11079
     
    11584    public override IEnumerable<double> EstimatedTrainingValues {
    11685      get {
    117         var rows = Enumerable.Range(ProblemData.TrainingPartition.Start, ProblemData.TrainingPartition.End - ProblemData.TrainingPartition.Start);
     86        var rows = ProblemData.TrainingIndizes;
    11887        var estimatedValuesEnumerators = (from model in Model.Models
    11988                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    12089                                         .ToList();
    12190        var rowsEnumerator = rows.GetEnumerator();
     91        // aggregate to make sure that MoveNext is called for all enumerators
    12292        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    12393          int currentRow = rowsEnumerator.Current;
     
    134104    public override IEnumerable<double> EstimatedTestValues {
    135105      get {
     106        var rows = ProblemData.TestIndizes;
    136107        var estimatedValuesEnumerators = (from model in Model.Models
    137                                           select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, ProblemData.TestIndizes).GetEnumerator() })
     108                                          select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    138109                                         .ToList();
    139110        var rowsEnumerator = ProblemData.TestIndizes.GetEnumerator();
     111        // aggregate to make sure that MoveNext is called for all enumerators
    140112        while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    141113          int currentRow = rowsEnumerator.Current;
     
    168140
    169141    private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    170       return estimatedValues.Average();
     142      return estimatedValues.DefaultIfEmpty(double.NaN).Average();
    171143    }
    172144
Note: See TracChangeset for help on using the changeset viewer.