Changeset 5885


Ignore:
Timestamp:
03/30/11 13:57:33 (11 years ago)
Author:
gkronber
Message:

#1418 added R² and MSE as results for discriminant function classification solutions.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r5849 r5885  
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     28using HeuristicLab.Data;
     29using HeuristicLab.Optimization;
    2830
    2931namespace HeuristicLab.Problems.DataAnalysis {
     
    3436  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    3537  public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
     38    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
     39    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
     40    private const string TrainingRSquaredResultName = "Pearson's R² (training)";
     41    private const string TestRSquaredResultName = "Pearson's R² (test)";
     42
    3643    public new IDiscriminantFunctionClassificationModel Model {
    3744      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
     
    4754    }
    4855
     56    public double TrainingMeanSquaredError {
     57      get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
     58      private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     59    }
     60
     61    public double TestMeanSquaredError {
     62      get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
     63      private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     64    }
     65
     66    public double TrainingRSquared {
     67      get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
     68      private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
     69    }
     70
     71    public double TestRSquared {
     72      get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
     73      private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
     74    }
     75
    4976    [StorableConstructor]
    5077    protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }
     
    5885    public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    5986      : base(model, problemData) {
     87      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
     88      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
     89      Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
     90      Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    6091      RegisterEventHandler();
    6192      SetAccuracyMaximizingThresholds();
     93      RecalculateResults();
    6294    }
    6395
     
    6597    private void AfterDeserialization() {
    6698      RegisterEventHandler();
     99    }
     100
     101    protected void RecalculateResults() {
     102      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
     103      IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     104      double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
     105      IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
     106
     107      double trainingMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
     108      double testMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTestValues, originalTestValues);
     109      double trainingR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
     110      double testR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTestValues, originalTestValues);
     111
     112      TrainingMeanSquaredError = trainingMSE;
     113      TestMeanSquaredError = testMSE;
     114      TrainingRSquared = trainingR2;
     115      TestRSquared = testR2;
    67116    }
    68117
     
    95144      base.OnModelChanged(e);
    96145      SetAccuracyMaximizingThresholds();
     146      RecalculateResults();
    97147    }
    98148
     
    100150      base.OnProblemDataChanged(e);
    101151      SetAccuracyMaximizingThresholds();
     152      RecalculateResults();
    102153    }
    103154    protected virtual void OnModelThresholdsChanged(EventArgs e) {
    104155      base.OnModelChanged(e);
     156      RecalculateResults();
    105157    }
    106158
Note: See TracChangeset for help on using the changeset viewer.