Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/16/12 15:19:40 (12 years ago)
Author:
sforsten
Message:

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation can now easily be restricted to training or test samples (via a CheckPoint handler passed to GetConfidence)
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r8811 r8814  
    2424using HeuristicLab.Common;
    2525using HeuristicLab.Core;
    26 using HeuristicLab.Data;
    2726using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2827
     
    112111    }
    113112
    114     public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     113    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    115114      if (solutions.Count() < 1)
    116115        return double.NaN;
    117116      Dataset dataset = solutions.First().ProblemData.Dataset;
    118117      var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() })
    119                                       .Where(a => a.Values.Equals(estimatedClassValue))
     118                                      .Where(a => handler(a.Solution.ProblemData, index) && a.Values.Equals(estimatedClassValue))
    120119                                      .Select(a => a.Solution);
    121120      return (from sol in correctSolutions
     
    123122    }
    124123
    125     public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     124    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    126125      if (solutions.Count() < 1)
    127126        return Enumerable.Repeat(double.NaN, indices.Count());
    128127
     128      List<int> indicesList = indices.ToList();
     129
    129130      Dataset dataset = solutions.First().ProblemData.Dataset;
    130       Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices).ToArray());
     131      Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
    131132      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    132       double[] confidences = new double[indices.Count()];
     133      double[] confidences = new double[indicesList.Count];
    133134
    134       for (int i = 0; i < indices.Count(); i++) {
     135      for (int i = 0; i < indicesList.Count; i++) {
    135136        var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i]));
    136137        confidences[i] = (from sol in correctSolutions
     138                          where handler(sol.Key.ProblemData, indicesList[i])
    137139                          select weights[sol.Key]).Sum();
    138140      }
     
    142144
    143145    #region Helper
    144     protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
    145       return from i in indizes
    146              select targetValues[i];
    147     }
    148146    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
    149       IntRange trainingPartition = problemData.TrainingPartition;
    150       IntRange testPartition = problemData.TestPartition;
    151       return (trainingPartition.Start <= point && point < trainingPartition.End)
    152         && !(testPartition.Start <= point && point < testPartition.End);
     147      return problemData.IsTrainingSample(point);
    153148    }
    154149    protected bool PointInTest(IClassificationProblemData problemData, int point) {
    155       IntRange testPartition = problemData.TestPartition;
    156       return testPartition.Start <= point && point < testPartition.End;
     150      return problemData.IsTestSample(point);
    157151    }
    158152    protected bool AllPoints(IClassificationProblemData problemData, int point) {
Note: See TracChangeset for help on using the changeset viewer.