Ignore:
Timestamp:
07/25/11 16:54:15 (11 years ago)
Author:
mkommend
Message:

#1600: Adapted classification solutions to the same design as used by regression solutions.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r6411 r6589  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Data;
    26 using HeuristicLab.Optimization;
    2725using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2826
     
    3230  /// </summary>
    3331  [StorableClass]
    34   public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    35     private const string TrainingAccuracyResultName = "Accuracy (training)";
    36     private const string TestAccuracyResultName = "Accuracy (test)";
    37 
    38     public new IClassificationModel Model {
    39       get { return (IClassificationModel)base.Model; }
    40       protected set { base.Model = value; }
    41     }
    42 
    43     public new IClassificationProblemData ProblemData {
    44       get { return (IClassificationProblemData)base.ProblemData; }
    45       protected set { base.ProblemData = value; }
    46     }
    47 
    48     public double TrainingAccuracy {
    49       get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
    50       private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    51     }
    52 
    53     public double TestAccuracy {
    54       get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
    55       private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    56     }
    57 
     32  public abstract class ClassificationSolution : ClassificationSolutionBase {
    5833    [StorableConstructor]
    5934    protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     
    6338    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    6439      : base(model, problemData) {
    65       Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
    66       Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
    67       CalculateResults();
    6840    }
    6941
    70     public override IDeepCloneable Clone(Cloner cloner) {
    71       return new ClassificationSolution(this, cloner);
     42    public override IEnumerable<double> EstimatedClassValues {
     43      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     44    }
     45    public override IEnumerable<double> EstimatedTrainingClassValues {
     46      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     47    }
     48    public override IEnumerable<double> EstimatedTestClassValues {
     49      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    7250    }
    7351
    74     protected override void RecalculateResults() {
    75       CalculateResults();
    76     }
    77 
    78     private void CalculateResults() {
    79       double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    80       IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    81       double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
    82       IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    83 
    84       OnlineCalculatorError errorState;
    85       double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState);
    86       if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN;
    87       double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState);
    88       if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN;
    89 
    90       TrainingAccuracy = trainingAccuracy;
    91       TestAccuracy = testAccuracy;
    92     }
    93 
    94     public virtual IEnumerable<double> EstimatedClassValues {
    95       get {
    96         return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    97       }
    98     }
    99 
    100     public virtual IEnumerable<double> EstimatedTrainingClassValues {
    101       get {
    102         return GetEstimatedClassValues(ProblemData.TrainingIndizes);
    103       }
    104     }
    105 
    106     public virtual IEnumerable<double> EstimatedTestClassValues {
    107       get {
    108         return GetEstimatedClassValues(ProblemData.TestIndizes);
    109       }
    110     }
    111 
    112     public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     52    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    11353      return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
    11454    }
Note: See TracChangeset for help on using the changeset viewer.