Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/01/11 17:48:53 (13 years ago)
Author:
mkommend
Message:

#1479: Integrated trunk changes.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs

    r6415 r6618  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Data;
    26 using HeuristicLab.Optimization;
    2725using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2826
     
    3230  /// </summary>
    3331  [StorableClass]
    34   public class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    35     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    36     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    37     private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
    38     private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
    39     private const string TrainingRelativeErrorResultName = "Average relative error (training)";
    40     private const string TestRelativeErrorResultName = "Average relative error (test)";
    41     private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)";
    42     private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)";
    43 
    44     public new IRegressionModel Model {
    45       get { return (IRegressionModel)base.Model; }
    46       protected set { base.Model = value; }
    47     }
    48 
    49     public new IRegressionProblemData ProblemData {
    50       get { return (IRegressionProblemData)base.ProblemData; }
    51       protected set { base.ProblemData = value; }
    52     }
    53 
    54     public double TrainingMeanSquaredError {
    55       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    56       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    57     }
    58 
    59     public double TestMeanSquaredError {
    60       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    61       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    62     }
    63 
    64     public double TrainingRSquared {
    65       get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
    66       private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
    67     }
    68 
    69     public double TestRSquared {
    70       get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
    71       private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
    72     }
    73 
    74     public double TrainingRelativeError {
    75       get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
    76       private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
    77     }
    78 
    79     public double TestRelativeError {
    80       get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
    81       private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
    82     }
    83 
    84     public double TrainingNormalizedMeanSquaredError {
    85       get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
    86       private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    87     }
    88 
    89     public double TestNormalizedMeanSquaredError {
    90       get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
    91       private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
    92     }
    93 
     32  public abstract class RegressionSolution : RegressionSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    9434
    9535    [StorableConstructor]
    96     protected RegressionSolution(bool deserializing) : base(deserializing) { }
     36    protected RegressionSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
    9740    protected RegressionSolution(RegressionSolution original, Cloner cloner)
    9841      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
    9943    }
    100     public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
     44    protected RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
    10145      : base(model, problemData) {
    102       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    103       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    104       Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    105       Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    106       Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue()));
    107       Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue()));
    108       Add(new Result(TrainingNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the training partition", new DoubleValue()));
    109       Add(new Result(TestNormalizedMeanSquaredErrorResultName, "Normalized mean of squared errors of the model on the test partition", new DoubleValue()));
    110 
    111       CalculateResults();
    112     }
    113 
    114     public override IDeepCloneable Clone(Cloner cloner) {
    115       return new RegressionSolution(this, cloner);
     46      evaluationCache = new Dictionary<int, double>();
    11647    }
    11748
     
    12051    }
    12152
    122     private void CalculateResults() {
    123       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    124       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    125       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    126       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    127 
    128       OnlineCalculatorError errorState;
    129       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    130       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    131       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    132       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    133 
    134       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    135       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    136       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    137       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    138 
    139       double trainingRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    140       TrainingRelativeError = errorState == OnlineCalculatorError.None ? trainingRelError : double.NaN;
    141       double testRelError = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    142       TestRelativeError = errorState == OnlineCalculatorError.None ? testRelError : double.NaN;
    143 
    144       double trainingNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    145       TrainingNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingNMSE : double.NaN;
    146       double testNMSE = OnlineNormalizedMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    147       TestNormalizedMeanSquaredError = errorState == OnlineCalculatorError.None ? testNMSE : double.NaN;
     53    public override IEnumerable<double> EstimatedValues {
     54      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     55    }
     56    public override IEnumerable<double> EstimatedTrainingValues {
     57      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
     58    }
     59    public override IEnumerable<double> EstimatedTestValues {
     60      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    14861    }
    14962
    150     public virtual IEnumerable<double> EstimatedValues {
    151       get {
    152         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
     63    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     64      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     65      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     66      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     67
     68      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     69        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
    15370      }
     71
     72      return rows.Select(row => evaluationCache[row]);
    15473    }
    15574
    156     public virtual IEnumerable<double> EstimatedTrainingValues {
    157       get {
    158         return GetEstimatedValues(ProblemData.TrainingIndizes);
    159       }
     75    protected override void OnProblemDataChanged() {
     76      evaluationCache.Clear();
     77      base.OnProblemDataChanged();
    16078    }
    16179
    162     public virtual IEnumerable<double> EstimatedTestValues {
    163       get {
    164         return GetEstimatedValues(ProblemData.TestIndizes);
    165       }
    166     }
    167 
    168     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    169       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     80    protected override void OnModelChanged() {
     81      evaluationCache.Clear();
     82      base.OnModelChanged();
    17083    }
    17184  }
Note: See TracChangeset for help on using the changeset viewer.