Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/01/11 17:48:53 (13 years ago)
Author:
mkommend
Message:

#1479: Integrated trunk changes.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r6415 r6618  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Data;
    26 using HeuristicLab.Optimization;
    2725using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2826
     
    3230  /// </summary>
    3331  [StorableClass]
    34   public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    35     private const string TrainingAccuracyResultName = "Accuracy (training)";
    36     private const string TestAccuracyResultName = "Accuracy (test)";
    37 
    38     public new IClassificationModel Model {
    39       get { return (IClassificationModel)base.Model; }
    40       protected set { base.Model = value; }
    41     }
    42 
    43     public new IClassificationProblemData ProblemData {
    44       get { return (IClassificationProblemData)base.ProblemData; }
    45       protected set { base.ProblemData = value; }
    46     }
    47 
    48     public double TrainingAccuracy {
    49       get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
    50       private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    51     }
    52 
    53     public double TestAccuracy {
    54       get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
    55       private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    56     }
     32  public abstract class ClassificationSolution : ClassificationSolutionBase {
     33    protected readonly Dictionary<int, double> evaluationCache;
    5734
    5835    [StorableConstructor]
    59     protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     36    protected ClassificationSolution(bool deserializing)
     37      : base(deserializing) {
     38      evaluationCache = new Dictionary<int, double>();
     39    }
    6040    protected ClassificationSolution(ClassificationSolution original, Cloner cloner)
    6141      : base(original, cloner) {
     42      evaluationCache = new Dictionary<int, double>(original.evaluationCache);
    6243    }
    6344    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    6445      : base(model, problemData) {
    65       Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
    66       Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
    67       CalculateResults();
     46      evaluationCache = new Dictionary<int, double>();
    6847    }
    6948
    70     public override IDeepCloneable Clone(Cloner cloner) {
    71       return new ClassificationSolution(this, cloner);
     49    public override IEnumerable<double> EstimatedClassValues {
     50      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     51    }
     52    public override IEnumerable<double> EstimatedTrainingClassValues {
     53      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     54    }
     55    public override IEnumerable<double> EstimatedTestClassValues {
     56      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    7257    }
    7358
    74     protected override void RecalculateResults() {
    75       CalculateResults();
     59    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     60      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
     61      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
     62      var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
     63
     64      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     65        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
     66      }
     67
     68      return rows.Select(row => evaluationCache[row]);
    7669    }
    7770
    78     private void CalculateResults() {
    79       double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    80       IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    81       double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
    82       IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    83 
    84       OnlineCalculatorError errorState;
    85       double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState);
    86       if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN;
    87       double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState);
    88       if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN;
    89 
    90       TrainingAccuracy = trainingAccuracy;
    91       TestAccuracy = testAccuracy;
     71    protected override void OnProblemDataChanged() {
     72      evaluationCache.Clear();
     73      base.OnProblemDataChanged();
    9274    }
    9375
    94     public virtual IEnumerable<double> EstimatedClassValues {
    95       get {
    96         return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    97       }
    98     }
    99 
    100     public virtual IEnumerable<double> EstimatedTrainingClassValues {
    101       get {
    102         return GetEstimatedClassValues(ProblemData.TrainingIndizes);
    103       }
    104     }
    105 
    106     public virtual IEnumerable<double> EstimatedTestClassValues {
    107       get {
    108         return GetEstimatedClassValues(ProblemData.TestIndizes);
    109       }
    110     }
    111 
    112     public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    113       return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
     76    protected override void OnModelChanged() {
     77      evaluationCache.Clear();
     78      base.OnModelChanged();
    11479    }
    11580  }
Note: See TracChangeset for help on using the changeset viewer.