Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/10/11 10:00:09 (13 years ago)
Author:
gkronber
Message:

#1418 Implemented classes for classification based on a discriminant function and thresholds and implemented interfaces and base classes for clustering.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/ClassificationSolution.cs

    r5624 r5649  
    3737  [StorableClass]
    3838  public abstract class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    39     private const string ThresholdsResultsName = "Thresholds";
     39    private const string TrainingAccuracyResultName = "Accuracy (training)";
     40    private const string TestAccuracyResultName = "Accuracy (test)";
    4041    [StorableConstructor]
    4142    protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     
    4546    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    4647      : base(model, problemData) {
    47       DoubleArray thresholds = new DoubleArray();
    48       Add(new Result(ThresholdsResultsName, "The threshold values for class boundaries.", thresholds));
    49       thresholds.Reset += new EventHandler(thresholds_Reset);
    50       thresholds.ItemChanged += new EventHandler<EventArgs<int>>(thresholds_ItemChanged);
     48      double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
     49      IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     50      double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
     51      IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
     52
     53      double trainingAccuracy = OnlineAccuracyEvaluator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues);
     54      double testAccuracy = OnlineAccuracyEvaluator.Calculate(estimatedTestClassValues, originalTestClassValues);
     55
     56      Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue(trainingAccuracy)));
     57      Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue(testAccuracy)));
    5158    }
    5259
     
    6168    }
    6269
    63     public virtual IEnumerable<double> EstimatedValues {
    64       get {
    65         return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    66       }
    67     }
    68 
    69     public virtual IEnumerable<double> EstimatedTrainingValues {
    70       get {
    71         return GetEstimatedValues(ProblemData.TrainingIndizes);
    72       }
    73     }
    74 
    75     public virtual IEnumerable<double> EstimatedTestValues {
    76       get {
    77         return GetEstimatedValues(ProblemData.TestIndizes);
    78       }
    79     }
    80 
    81     public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    82       return Model.GetEstimatedValues(ProblemData, rows);
    83     }
    84 
    85     public IEnumerable<double> Thresholds {
    86       get {
    87         return (DoubleArray)this[ThresholdsResultsName].Value;
    88       }
    89     }
    90 
    91     public IEnumerable<double> EstimatedClassValues {
     70    public virtual IEnumerable<double> EstimatedClassValues {
    9271      get {
    9372        return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
     
    9574    }
    9675
    97     public IEnumerable<double> EstimatedTrainingClassValues {
     76    public virtual IEnumerable<double> EstimatedTrainingClassValues {
    9877      get {
    9978        return GetEstimatedClassValues(ProblemData.TrainingIndizes);
     
    10180    }
    10281
    103     public IEnumerable<double> EstimatedTestClassValues {
     82    public virtual IEnumerable<double> EstimatedTestClassValues {
    10483      get {
    10584        return GetEstimatedClassValues(ProblemData.TestIndizes);
     
    10786    }
    10887
    109     public IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    110       return Model.GetEstimatedClassValues(ProblemData, rows);
    111     }
    112 
    113     #endregion
    114     #region events
    115     private void thresholds_ItemChanged(object sender, EventArgs<int> e) {
    116       OnThresholdsChanged(EventArgs.Empty);
    117     }
    118 
    119     private void thresholds_Reset(object sender, EventArgs e) {
    120       OnThresholdsChanged(EventArgs.Empty);
    121     }
    122 
    123     public event EventHandler ThresholdsChanged;
    124     private void OnThresholdsChanged(EventArgs e) {
    125       var listeners = ThresholdsChanged;
    126       if (listeners != null) listeners(this, e);
     88    public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     89      return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
    12790    }
    12891    #endregion
Note: See TracChangeset for help on using the changeset viewer.