Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/29/08 11:21:04 (15 years ago)
Author:
gkronber
Message:

fixed #328 by restructuring evaluation operators to remove state in evaluation operators.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/AccuracyEvaluator.cs

    r668 r702  
    2727using HeuristicLab.Data;
    2828using HeuristicLab.GP.StructureIdentification;
     29using HeuristicLab.DataAnalysis;
    2930
    3031namespace HeuristicLab.GP.StructureIdentification.Classification {
    31   public class AccuracyEvaluator : GPEvaluatorBase {
     32  public class AccuracyEvaluator : GPClassificationEvaluatorBase {
    3233    private const double EPSILON = 1.0E-6;
    33     private double[] classesArr;
    34     private double[] thresholds;
    35     private DoubleData accuracy;
    3634    public override string Description {
    3735      get {
     
    4341      : base() {
    4442      AddVariableInfo(new VariableInfo("Accuracy", "The total accuracy of the model (ratio of correctly classified instances to total number of instances)", typeof(DoubleData), VariableKind.New));
    45       AddVariableInfo(new VariableInfo("TargetClassValues", "The original class values of target variable (for instance negative=0 and positive=1).", typeof(ItemList<DoubleData>), VariableKind.In));
    4643    }
    4744
    48     public override IOperation Apply(IScope scope) {
    49       accuracy = GetVariableValue<DoubleData>("Accuracy", scope, false, false);
     45    public override void Evaluate(IScope scope, BakedTreeEvaluator evaluator, Dataset dataset, int targetVariable, double[] classes, double[] thresholds, int start, int end) {
     46      DoubleData accuracy = GetVariableValue<DoubleData>("Accuracy", scope, false, false);
    5047      if(accuracy == null) {
    5148        accuracy = new DoubleData();
     
    5350      }
    5451
    55       ItemList<DoubleData> classes = GetVariableValue<ItemList<DoubleData>>("TargetClassValues", scope, true);
    56       classesArr = new double[classes.Count];
    57       for(int i = 0; i < classesArr.Length; i++) classesArr[i] = classes[i].Data;
    58       Array.Sort(classesArr);
    59       thresholds = new double[classes.Count - 1];
    60       for(int i = 0; i < classesArr.Length - 1; i++) {
    61         thresholds[i] = (classesArr[i] + classesArr[i + 1]) / 2.0;
    62       }
    63 
    64       return base.Apply(scope);
    65     }
    66 
    67     public override void Evaluate(int start, int end) {
    6852      int nSamples = end - start;
    6953      int nCorrect = 0;
    7054      for(int sample = start; sample < end; sample++) {
    71         double est = GetEstimatedValue(sample);
    72         double origClass = GetOriginalValue(sample);
     55        double est = evaluator.Evaluate(sample);
     56        double origClass = dataset.GetValue(targetVariable, sample);
    7357        double estClass = double.NaN;
    7458        // if estimation is lower than the smallest threshold value -> estimated class is the lower class
    75         if(est < thresholds[0]) estClass = classesArr[0];
     59        if(est < thresholds[0]) estClass = classes[0];
    7660        // if estimation is larger (or equal) than the largest threshold value -> estimated class is the upper class
    77         else if(est >= thresholds[thresholds.Length - 1]) estClass = classesArr[classesArr.Length - 1];
     61        else if(est >= thresholds[thresholds.Length - 1]) estClass = classes[classes.Length - 1];
    7862        else {
    7963          // otherwise the estimated class is the class which upper threshold is larger than the estimated value
    8064          for(int k = 0; k < thresholds.Length; k++) {
    8165            if(thresholds[k] > est) {
    82               estClass = classesArr[k];
     66              estClass = classes[k];
    8367              break;
    8468            }
    8569          }
    8670        }
    87         SetOriginalValue(sample, estClass);
    8871        if(Math.Abs(estClass - origClass) < EPSILON) nCorrect++;
    8972      }
Note: See TracChangeset for help on using the changeset viewer.