Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/16/12 15:19:40 (12 years ago)
Author:
sforsten
Message:

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation is now easier for training and test
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs

    r8534 r8814  
    5757    }
    5858
    59     protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
     59    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    6060      Dataset dataset = solutions.First().ProblemData.Dataset;
    61       IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     61      IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
    6262      if (values.Count <= 0)
    6363        return double.NaN;
     
    6666    }
    6767
    68     public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     68    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    6969      Dataset dataset = solutions.First().ProblemData.Dataset;
    70       double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     70      List<int> indicesList = indices.ToList();
     71      var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray());
    7172      double[] confidences = new double[indices.Count()];
    7273      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    7374
    74       for (int i = 0; i < indices.Count(); i++) {
    75         double avg = values.Select(x => x[i]).Average();
    76         confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]);
     75      for (int i = 0; i < indicesList.Count; i++) {
     76        var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]);
     77        if (values.Count() <= 0) {
     78          confidences[i] = double.NaN;
     79        } else {
     80          double avg = values.Average();
     81          confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]);
     82        }
    7783      }
    7884
     
    8490        if (estimatedClassValue.Equals(classValues[i])) {
    8591          //special case: avgerage is higher than value of highest class
    86           if (i == classValues.Length - 1 && avg > estimatedClassValue) {
     92          if (i == classValues.Length - 1 && avg >= estimatedClassValue) {
    8793            return 1;
    8894          }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r8811 r8814  
    2424using HeuristicLab.Common;
    2525using HeuristicLab.Core;
    26 using HeuristicLab.Data;
    2726using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2827
     
    112111    }
    113112
    114     public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     113    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    115114      if (solutions.Count() < 1)
    116115        return double.NaN;
    117116      Dataset dataset = solutions.First().ProblemData.Dataset;
    118117      var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() })
    119                                       .Where(a => a.Values.Equals(estimatedClassValue))
     118                                      .Where(a => handler(a.Solution.ProblemData, index) && a.Values.Equals(estimatedClassValue))
    120119                                      .Select(a => a.Solution);
    121120      return (from sol in correctSolutions
     
    123122    }
    124123
    125     public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     124    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    126125      if (solutions.Count() < 1)
    127126        return Enumerable.Repeat(double.NaN, indices.Count());
    128127
     128      List<int> indicesList = indices.ToList();
     129
    129130      Dataset dataset = solutions.First().ProblemData.Dataset;
    130       Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices).ToArray());
     131      Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
    131132      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    132       double[] confidences = new double[indices.Count()];
     133      double[] confidences = new double[indicesList.Count];
    133134
    134       for (int i = 0; i < indices.Count(); i++) {
     135      for (int i = 0; i < indicesList.Count; i++) {
    135136        var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i]));
    136137        confidences[i] = (from sol in correctSolutions
     138                          where handler(sol.Key.ProblemData, indicesList[i])
    137139                          select weights[sol.Key]).Sum();
    138140      }
     
    142144
    143145    #region Helper
    144     protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
    145       return from i in indizes
    146              select targetValues[i];
    147     }
    148146    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
    149       IntRange trainingPartition = problemData.TrainingPartition;
    150       IntRange testPartition = problemData.TestPartition;
    151       return (trainingPartition.Start <= point && point < trainingPartition.End)
    152         && !(testPartition.Start <= point && point < testPartition.End);
     147      return problemData.IsTrainingSample(point);
    153148    }
    154149    protected bool PointInTest(IClassificationProblemData problemData, int point) {
    155       IntRange testPartition = problemData.TestPartition;
    156       return testPartition.Start <= point && point < testPartition.End;
     150      return problemData.IsTestSample(point);
    157151    }
    158152    protected bool AllPoints(IClassificationProblemData problemData, int point) {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r8811 r8814  
    8282    }
    8383
    84     public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     84    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    8585      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    8686        return double.NaN;
     
    8888      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
    8989
    90       return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue);
     90      return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue, handler);
    9191    }
    9292
    93     protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    94       return base.GetConfidence(solutions, index, estimatedClassValue);
     93    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
     94      return base.GetConfidence(solutions, index, estimatedClassValue, handler);
    9595    }
    9696
    97     public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     97    public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    9898      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    9999        return Enumerable.Repeat(double.NaN, indices.Count());
     
    101101      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
    102102
    103       return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue);
     103      return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue, handler);
    104104    }
    105105
    106     public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
    107       return base.GetConfidence(solutions, indices, estimatedClassValue);
     106    public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
     107      return base.GetConfidence(solutions, indices, estimatedClassValue, handler);
    108108    }
    109109  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs

    r8297 r8814  
    5252    }
    5353
    54     public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     54    public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    5555      if (solutions.Count() < 1)
    5656        return double.NaN;
    57       Dataset dataset = solutions.First().ProblemData.Dataset;
    58       int correctEstimated = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First())
     57      var votingSolutions = solutions.Where(s => handler(s.ProblemData, index));
     58      if (votingSolutions.Count() < 1)
     59        return double.NaN;
     60      Dataset dataset = votingSolutions.First().ProblemData.Dataset;
     61      int correctEstimated = votingSolutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First())
    5962                                      .Where(x => x.Equals(estimatedClassValue))
    6063                                      .Count();
    61       return ((double)correctEstimated / (double)solutions.Count() - 0.5) * 2;
     64      return ((double)correctEstimated / (double)votingSolutions.Count() - 0.5) * 2;
    6265    }
    6366
    64     public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     67    public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    6568      if (solutions.Count() < 1)
    6669        return Enumerable.Repeat(double.NaN, indices.Count());
     70
     71      List<int> indicesList = indices.ToList();
    6772      Dataset dataset = solutions.First().ProblemData.Dataset;
    68       var estimationsPerSolution = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, indices).ToArray()).ToArray();
     73      var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
    6974      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    7075      double correctEstimated;
    7176      double[] confidences = new double[indices.Count()];
    7277
    73       for (int i = 0; i < indices.Count(); i++) {
    74         correctEstimated = estimationsPerSolution.Where(x => DoubleExtensions.IsAlmost(x[i], estimatedClassValueArr[i])).Count();
    75         confidences[i] = (correctEstimated / (double)solutions.Count() - 0.5) * 2;
     78      for (int i = 0; i < indicesList.Count; i++) {
     79        var votingSolutions = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i]));
     80        correctEstimated = votingSolutions.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])).Count();
     81        confidences[i] = (correctEstimated / (double)votingSolutions.Count() - 0.5) * 2;
    7682      }
    7783
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r8534 r8814  
    5757    }
    5858
    59     protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    60 
     59    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    6160      Dataset dataset = solutions.First().ProblemData.Dataset;
    62       IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     61      IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
    6362      if (values.Count <= 0)
    6463        return double.NaN;
     
    6766    }
    6867
    69     public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     68    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    7069
    7170      Dataset dataset = solutions.First().ProblemData.Dataset;
    72       double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     71      List<int> indicesList = indices.ToList();
     72      var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray());
    7373      double[] confidences = new double[indices.Count()];
    7474      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    7575
    76       for (int i = 0; i < indices.Count(); i++) {
    77         double avg = values.Select(x => x[i]).Average();
    78         confidences[i] = GetMedianConfidence(avg, estimatedClassValueArr[i]);
     76      for (int i = 0; i < indicesList.Count; i++) {
     77        var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]).ToList();
     78        if (values.Count <= 0) {
     79          confidences[i] = double.NaN;
     80        } else {
     81          double median = GetMedian(values);
     82          confidences[i] = GetMedianConfidence(median, estimatedClassValueArr[i]);
     83        }
    7984      }
    8085
     
    8590      for (int i = 0; i < classValues.Length; i++) {
    8691        if (estimatedClassValue.Equals(classValues[i])) {
    87           //special case: avgerage is higher than value of highest class
    88           if (i == classValues.Length - 1 && median > estimatedClassValue) {
     92          //special case: median is higher than value of highest class
     93          if (i == classValues.Length - 1 && median >= estimatedClassValue) {
    8994            return 1;
    9095          }
    91           //special case: average is lower than value of lowest class
     96          //special case: median is lower than value of lowest class
    9297          if (i == 0 && median < estimatedClassValue) {
    9398            return 1;
    9499          }
    95           //special case: average is not between threshold of estimated class value
     100          //special case: median is not between threshold of estimated class value
    96101          if ((i < classValues.Length - 1 && median >= threshold[i + 1]) || median <= threshold[i]) {
    97102            return 0;
    98103          }
    99104
    100           double thresholdToClassDistance, thresholdToAverageValueDistance;
     105          double thresholdToClassDistance, thresholdToMedianValueDistance;
    101106          if (median >= classValues[i]) {
    102107            thresholdToClassDistance = threshold[i + 1] - classValues[i];
    103             thresholdToAverageValueDistance = threshold[i + 1] - median;
     108            thresholdToMedianValueDistance = threshold[i + 1] - median;
    104109          } else {
    105110            thresholdToClassDistance = classValues[i] - threshold[i];
    106             thresholdToAverageValueDistance = median - threshold[i];
     111            thresholdToMedianValueDistance = median - threshold[i];
    107112          }
    108           return (1 / thresholdToClassDistance) * thresholdToAverageValueDistance;
     113          return (1 / thresholdToClassDistance) * thresholdToMedianValueDistance;
    109114        }
    110115      }
Note: See TracChangeset for help on using the changeset viewer.