Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/29/08 11:21:04 (16 years ago)
Author:
gkronber
Message:

fixed #328 by restructuring evaluation operators to remove state in evaluation operators.

Location:
trunk/sources/HeuristicLab.GP.StructureIdentification.Classification
Files:
1 added
1 deleted
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/AccuracyEvaluator.cs

    r668 r702  
    2727using HeuristicLab.Data;
    2828using HeuristicLab.GP.StructureIdentification;
     29using HeuristicLab.DataAnalysis;
    2930
    3031namespace HeuristicLab.GP.StructureIdentification.Classification {
    31   public class AccuracyEvaluator : GPEvaluatorBase {
     32  public class AccuracyEvaluator : GPClassificationEvaluatorBase {
    3233    private const double EPSILON = 1.0E-6;
    33     private double[] classesArr;
    34     private double[] thresholds;
    35     private DoubleData accuracy;
    3634    public override string Description {
    3735      get {
     
    4341      : base() {
    4442      AddVariableInfo(new VariableInfo("Accuracy", "The total accuracy of the model (ratio of correctly classified instances to total number of instances)", typeof(DoubleData), VariableKind.New));
    45       AddVariableInfo(new VariableInfo("TargetClassValues", "The original class values of target variable (for instance negative=0 and positive=1).", typeof(ItemList<DoubleData>), VariableKind.In));
    4643    }
    4744
    48     public override IOperation Apply(IScope scope) {
    49       accuracy = GetVariableValue<DoubleData>("Accuracy", scope, false, false);
     45    public override void Evaluate(IScope scope, BakedTreeEvaluator evaluator, Dataset dataset, int targetVariable, double[] classes, double[] thresholds, int start, int end) {
     46      DoubleData accuracy = GetVariableValue<DoubleData>("Accuracy", scope, false, false);
    5047      if(accuracy == null) {
    5148        accuracy = new DoubleData();
     
    5350      }
    5451
    55       ItemList<DoubleData> classes = GetVariableValue<ItemList<DoubleData>>("TargetClassValues", scope, true);
    56       classesArr = new double[classes.Count];
    57       for(int i = 0; i < classesArr.Length; i++) classesArr[i] = classes[i].Data;
    58       Array.Sort(classesArr);
    59       thresholds = new double[classes.Count - 1];
    60       for(int i = 0; i < classesArr.Length - 1; i++) {
    61         thresholds[i] = (classesArr[i] + classesArr[i + 1]) / 2.0;
    62       }
    63 
    64       return base.Apply(scope);
    65     }
    66 
    67     public override void Evaluate(int start, int end) {
    6852      int nSamples = end - start;
    6953      int nCorrect = 0;
    7054      for(int sample = start; sample < end; sample++) {
    71         double est = GetEstimatedValue(sample);
    72         double origClass = GetOriginalValue(sample);
     55        double est = evaluator.Evaluate(sample);
     56        double origClass = dataset.GetValue(targetVariable, sample);
    7357        double estClass = double.NaN;
    7458        // if estimation is lower than the smallest threshold value -> estimated class is the lower class
    75         if(est < thresholds[0]) estClass = classesArr[0];
     59        if(est < thresholds[0]) estClass = classes[0];
    7660        // if estimation is larger (or equal) than the largest threshold value -> estimated class is the upper class
    77         else if(est >= thresholds[thresholds.Length - 1]) estClass = classesArr[classesArr.Length - 1];
     61        else if(est >= thresholds[thresholds.Length - 1]) estClass = classes[classes.Length - 1];
    7862        else {
    7963          // otherwise the estimated class is the class which upper threshold is larger than the estimated value
    8064          for(int k = 0; k < thresholds.Length; k++) {
    8165            if(thresholds[k] > est) {
    82               estClass = classesArr[k];
     66              estClass = classes[k];
    8367              break;
    8468            }
    8569          }
    8670        }
    87         SetOriginalValue(sample, estClass);
    8871        if(Math.Abs(estClass - origClass) < EPSILON) nCorrect++;
    8972      }
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/ClassificationMeanSquaredErrorEvaluator.cs

    r668 r702  
    2929
    3030namespace HeuristicLab.GP.StructureIdentification.Classification {
    31   public class ClassificationMeanSquaredErrorEvaluator : MeanSquaredErrorEvaluator {
    32     private const double EPSILON = 1.0E-6;
    33     private double[] classesArr;
     31  public class ClassificationMeanSquaredErrorEvaluator : GPClassificationEvaluatorBase {
     32    private const double EPSILON = 1.0E-7;
    3433    public override string Description {
    3534      get {
     
    4140    public ClassificationMeanSquaredErrorEvaluator()
    4241      : base() {
    43       AddVariableInfo(new VariableInfo("TargetClassValues", "The original class values of target variable (for instance negative=0 and positive=1).", typeof(ItemList<DoubleData>), VariableKind.In));
     42      AddVariableInfo(new VariableInfo("MSE", "The mean squared error of the model", typeof(DoubleData), VariableKind.New));
    4443    }
    4544
    46     public override IOperation Apply(IScope scope) {
    47       ItemList<DoubleData> classes = GetVariableValue<ItemList<DoubleData>>("TargetClassValues", scope, true);
    48       classesArr = new double[classes.Count];
    49       for(int i = 0; i < classesArr.Length; i++) classesArr[i] = classes[i].Data;
    50       Array.Sort(classesArr);
    51       return base.Apply(scope);
    52     }
    53 
    54     public override void Evaluate(int start, int end) {
     45    public override void  Evaluate(IScope scope, BakedTreeEvaluator evaluator, HeuristicLab.DataAnalysis.Dataset dataset, int targetVariable, double[] classes, double[] thresholds, int start, int end)
     46{
    5547      double errorsSquaredSum = 0;
    5648      for(int sample = start; sample < end; sample++) {
    57         double estimated = GetEstimatedValue(sample);
    58         double original = GetOriginalValue(sample);
    59         SetOriginalValue(sample, estimated);
     49        double estimated = evaluator.Evaluate(sample);
     50        double original = dataset.GetValue(targetVariable, sample);
    6051        if(!double.IsNaN(original) && !double.IsInfinity(original)) {
    6152          double error = estimated - original;
     
    6354          // on the lower end and upper end only add linear error if the absolute error is larger than 1
    6455          // the error>1.0 constraint is needed for balance because in the interval ]-1, 1[ the squared error is smaller than the absolute error
    65           if((IsEqual(original, classesArr[0]) && error < -1.0) ||
    66             (IsEqual(original, classesArr[classesArr.Length - 1]) && error > 1.0)) {
     56          if((IsEqual(original, classes[0]) && error < -1.0) ||
     57            (IsEqual(original, classes[classes.Length - 1]) && error > 1.0)) {
    6758            errorsSquaredSum += Math.Abs(error); // only add linear error below the smallest class or above the largest class
    6859          } else {
     
    7667        errorsSquaredSum = double.MaxValue;
    7768      }
     69
     70      DoubleData mse = GetVariableValue<DoubleData>("MSE", scope, false, false);
     71      if(mse == null) {
     72        mse = new DoubleData();
     73        scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName("MSE"), mse));
     74      }
     75
    7876      mse.Data = errorsSquaredSum;
    7977    }
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/ConfusionMatrixEvaluator.cs

    r668 r702  
    2929
    3030namespace HeuristicLab.GP.StructureIdentification.Classification {
    31   public class ConfusionMatrixEvaluator : GPEvaluatorBase {
    32     private const double EPSILON = 1.0E-6;
    33     private double[] classesArr;
    34     private double[] thresholds;
    35     private IntMatrixData matrix;
     31  public class ConfusionMatrixEvaluator : GPClassificationEvaluatorBase {
    3632    public override string Description {
    3733      get {
     
    4339      : base() {
    4440      AddVariableInfo(new VariableInfo("ConfusionMatrix", "The confusion matrix of the model", typeof(IntMatrixData), VariableKind.New));
    45       AddVariableInfo(new VariableInfo("TargetClassValues", "The original class values of target variable (for instance negative=0 and positive=1).", typeof(ItemList<DoubleData>), VariableKind.In));
    4641    }
    4742
    48     public override IOperation Apply(IScope scope) {
    49       ItemList<DoubleData> classes = GetVariableValue<ItemList<DoubleData>>("TargetClassValues", scope, true);
    50       classesArr = new double[classes.Count];
    51       for(int i = 0; i < classesArr.Length; i++) classesArr[i] = classes[i].Data;
    52       Array.Sort(classesArr);
    53       thresholds = new double[classes.Count - 1];
    54       for(int i = 0; i < classesArr.Length - 1; i++) {
    55         thresholds[i] = (classesArr[i] + classesArr[i + 1]) / 2.0;
     43    public override void Evaluate(IScope scope, BakedTreeEvaluator evaluator, HeuristicLab.DataAnalysis.Dataset dataset, int targetVariable, double[] classes, double[] thresholds, int start, int end) {
     44      IntMatrixData matrix = GetVariableValue<IntMatrixData>("ConfusionMatrix", scope, false, false);
     45      if(matrix == null) {
     46        matrix = new IntMatrixData(new int[classes.Length, classes.Length]);
     47        scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName("ConfusionMatrix"), matrix));
    5648      }
    5749
    58       matrix = GetVariableValue<IntMatrixData>("ConfusionMatrix", scope, false, false);
    59       if(matrix == null) {
    60         matrix = new IntMatrixData(new int[classesArr.Length, classesArr.Length]);
    61         scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName("ConfusionMatrix"), matrix));
    62       }
    63       return base.Apply(scope);
    64     }
    65 
    66     public override void Evaluate(int start, int end) {
    6750      int nSamples = end - start;
    6851      for(int sample = start; sample < end; sample++) {
    69         double est = GetEstimatedValue(sample);
    70         double origClass = GetOriginalValue(sample);
     52        double est = evaluator.Evaluate(sample);
     53        double origClass = dataset.GetValue(targetVariable,sample);
    7154        int estClassIndex = -1;
    7255        // if estimation is lower than the smallest threshold value -> estimated class is the lower class
    7356        if(est < thresholds[0]) estClassIndex = 0;
    7457        // if estimation is larger (or equal) than the largest threshold value -> estimated class is the upper class
    75         else if(est >= thresholds[thresholds.Length - 1]) estClassIndex = classesArr.Length - 1;
     58        else if(est >= thresholds[thresholds.Length - 1]) estClassIndex = classes.Length - 1;
    7659        else {
    7760          // otherwise the estimated class is the class which upper threshold is larger than the estimated value
     
    8366          }
    8467        }
    85         SetOriginalValue(sample, classesArr[estClassIndex]);
    8668
    87         int origClassIndex = -1;
    88         for(int i = 0; i < classesArr.Length; i++) {
    89           if(IsEqual(origClass, classesArr[i])) origClassIndex = i;
     69        // find the first threshold index that is larger to the original value
     70        int origClassIndex = classes.Length-1;
     71        for(int i = 0; i < thresholds.Length; i++) {
     72          if(origClass < thresholds[i]) {
     73            origClassIndex = i;
     74            break;
     75          }
    9076        }
    9177        matrix.Data[origClassIndex, estClassIndex]++;
    9278      }
    9379    }
    94 
    95     private bool IsEqual(double x, double y) {
    96       return Math.Abs(x - y) < EPSILON;
    97     }
    9880  }
    9981}
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/HeuristicLab.GP.StructureIdentification.Classification.csproj

    r669 r702  
    6767    <Compile Include="ClassificationMeanSquaredErrorEvaluator.cs" />
    6868    <Compile Include="ConfusionMatrixEvaluator.cs" />
     69    <Compile Include="GPClassificationEvaluatorBase.cs" />
    6970    <Compile Include="CrossValidation.cs" />
    7071    <Compile Include="FunctionLibraryInjector.cs">
     
    7273    </Compile>
    7374    <Compile Include="HeuristicLabGPClassificationPlugin.cs" />
    74     <Compile Include="MCCEvaluator.cs" />
    7575    <Compile Include="MulticlassModeller.cs" />
    7676    <Compile Include="MulticlassOneVsOneAnalyzer.cs" />
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/MulticlassOneVsOneAnalyzer.cs

    r668 r702  
    7979
    8080        BakedTreeEvaluator evaluator = new BakedTreeEvaluator();
    81         evaluator.ResetEvaluator(functionTree, dataset);
     81        evaluator.ResetEvaluator(functionTree, dataset, targetVariable, samplesStart, samplesEnd, 1.0);
    8282
    8383        for(int i = 0; i < (samplesEnd - samplesStart); i++) {
Note: See TracChangeset for help on using the changeset viewer.