Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp: 08/01/11 17:48:53 (13 years ago)
Author: mkommend
Message: #1479: Integrated trunk changes.

File: 1 edited

Legend:

Unmodified
Added
Removed
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r6415 r6618  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
    27 using HeuristicLab.Data;
    28 using HeuristicLab.Optimization;
    2926using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3027
     
    3532  [StorableClass]
    3633  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    37   public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    38     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    39     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    40     private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    41     private const string TestRSquaredResultName = "Pearson's R² (test)";
     34  public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase {
     35    protected readonly Dictionary<int, double> valueEvaluationCache;
     36    protected readonly Dictionary<int, double> classValueEvaluationCache;
    4237
    43     public new IDiscriminantFunctionClassificationModel Model {
    44       get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    45       protected set {
    46         if (value != null && value != Model) {
    47           if (Model != null) {
    48             Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    49           }
    50           value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    51           base.Model = value;
    52         }
    53       }
     38    [StorableConstructor]
     39    protected DiscriminantFunctionClassificationSolution(bool deserializing)
     40      : base(deserializing) {
     41      valueEvaluationCache = new Dictionary<int, double>();
     42      classValueEvaluationCache = new Dictionary<int, double>();
     43    }
     44    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
     45      : base(original, cloner) {
     46      valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache);
     47      classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache);
     48    }
     49    protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     50      : base(model, problemData) {
     51      valueEvaluationCache = new Dictionary<int, double>();
     52      classValueEvaluationCache = new Dictionary<int, double>();
     53
     54      SetAccuracyMaximizingThresholds();
    5455    }
    5556
    56     public double TrainingMeanSquaredError {
    57       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    58       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
     57    public override IEnumerable<double> EstimatedClassValues {
     58      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     59    }
     60    public override IEnumerable<double> EstimatedTrainingClassValues {
     61      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     62    }
     63    public override IEnumerable<double> EstimatedTestClassValues {
     64      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    5965    }
    6066
    61     public double TestMeanSquaredError {
    62       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    63       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
     67    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     68      var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys);
     69      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     70      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     71
     72      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     73        classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     74      }
     75
     76      return rows.Select(row => classValueEvaluationCache[row]);
    6477    }
    6578
    66     public double TrainingRSquared {
    67       get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
    68       private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    69     }
    7079
    71     public double TestRSquared {
    72       get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
    73       private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    74     }
    75 
    76     [StorableConstructor]
    77     protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }
    78     protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    79       : base(original, cloner) {
    80       RegisterEventHandler();
    81     }
    82     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    83       : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    84     }
    85     public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    86       : base(model, problemData) {
    87       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    88       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    89       Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    90       Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    91       SetAccuracyMaximizingThresholds();
    92 
    93       //mkommend: important to recalculate accuracy because during the calculation before no thresholds were present     
    94       base.RecalculateResults();
    95       CalculateResults();
    96       RegisterEventHandler();
    97     }
    98 
    99     [StorableHook(HookType.AfterDeserialization)]
    100     private void AfterDeserialization() {
    101       RegisterEventHandler();
    102     }
    103 
    104     protected override void OnModelChanged(EventArgs e) {
    105       DeregisterEventHandler();
    106       SetAccuracyMaximizingThresholds();
    107       RegisterEventHandler();
    108       base.OnModelChanged(e);
    109     }
    110 
    111     protected override void RecalculateResults() {
    112       base.RecalculateResults();
    113       CalculateResults();
    114     }
    115 
    116     private void CalculateResults() {
    117       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    118       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    119       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    120       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    121 
    122       OnlineCalculatorError errorState;
    123       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    124       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    125       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    126       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    127 
    128       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    129       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    130       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    131       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    132     }
    133 
    134     private void RegisterEventHandler() {
    135       Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    136     }
    137     private void DeregisterEventHandler() {
    138       Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    139     }
    140     private void Model_ThresholdsChanged(object sender, EventArgs e) {
    141       OnModelThresholdsChanged(e);
    142     }
    143 
    144     public void SetAccuracyMaximizingThresholds() {
    145       double[] classValues;
    146       double[] thresholds;
    147       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    148       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    149 
    150       Model.SetThresholdsAndClassValues(thresholds, classValues);
    151     }
    152 
    153     public void SetClassDistibutionCutPointThresholds() {
    154       double[] classValues;
    155       double[] thresholds;
    156       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    157       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    158 
    159       Model.SetThresholdsAndClassValues(thresholds, classValues);
    160     }
    161 
    162     protected virtual void OnModelThresholdsChanged(EventArgs e) {
    163       RecalculateResults();
    164     }
    165 
    166     public IEnumerable<double> EstimatedValues {
     80    public override IEnumerable<double> EstimatedValues {
    16781      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    16882    }
    169 
    170     public IEnumerable<double> EstimatedTrainingValues {
     83    public override IEnumerable<double> EstimatedTrainingValues {
    17184      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    17285    }
    173 
    174     public IEnumerable<double> EstimatedTestValues {
     86    public override IEnumerable<double> EstimatedTestValues {
    17587      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    17688    }
    17789
    178     public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    179       return Model.GetEstimatedValues(ProblemData.Dataset, rows);
     90    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     91      var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys);
     92      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     93      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     94
     95      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     96        valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     97      }
     98
     99      return rows.Select(row => valueEvaluationCache[row]);
     100    }
     101
     102    protected override void OnModelChanged() {
     103      valueEvaluationCache.Clear();
     104      classValueEvaluationCache.Clear();
     105      base.OnModelChanged();
     106    }
     107    protected override void OnModelThresholdsChanged(System.EventArgs e) {
     108      classValueEvaluationCache.Clear();
     109      base.OnModelThresholdsChanged(e);
     110    }
     111    protected override void OnProblemDataChanged() {
     112      valueEvaluationCache.Clear();
     113      classValueEvaluationCache.Clear();
     114      base.OnProblemDataChanged();
    180115    }
    181116  }
Note: See TracChangeset for help on using the changeset viewer.