Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/16/12 15:19:40 (12 years ago)
Author:
sforsten
Message:

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation is now easier for training and test
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
Location:
branches/ClassificationEnsembleVoting
Files:
2 deleted
10 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs

    r8811 r8814  
    2626using System.Windows.Forms;
    2727using System.Windows.Forms.DataVisualization.Charting;
    28 using HeuristicLab.Data;
    2928using HeuristicLab.MainForm;
    3029
     
    8281        IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
    8382        var solutions = Content.ClassificationSolutions;
    84         double[] estimatedClassValues;
     83        double[] estimatedClassValues = null;
    8584        double[] classValues;
    8685        OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator();
    8786
    88         int rows;
    89         double[] confidences;
     87        int rows = 0;
     88        double[] confidences = null;
     89
     90        classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();
    9091
    9192        if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxAllSamples)) {
    9293          rows = Content.ProblemData.Dataset.Rows;
    9394          estimatedClassValues = Content.EstimatedClassValues.ToArray();
    94           classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();
    95           confidences = weightCalc.GetConfidence(solutions, Enumerable.Range(0, rows), estimatedClassValues).ToArray();
    96         } else {
    97           IntRange range;
    98           if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTrainingSamples)) {
    99             range = Content.ProblemData.TrainingPartition;
    100             estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
    101           } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTestSamples)) {
    102             range = Content.ProblemData.TestPartition;
    103             estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
    104           } else {
    105             return;
    106           }
    107           rows = range.End - range.Start;
    108           classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable)
    109                                       .Skip(range.Start).Take(range.End - range.Start).ToArray();
    110           confidences = new double[rows];
    111           int index;
    112           for (int i = 0; i < rows; i++) {
    113             index = range.Start + i;
    114             confidences[i] = weightCalc.GetConfidence(GetRelevantSolutions(SamplesComboBox.SelectedItem.ToString(), solutions, index),
    115                                                       index, estimatedClassValues[i]);
    116           }
     95          confidences = weightCalc.GetConfidence(solutions,
     96                                                 Enumerable.Range(0, Content.ProblemData.Dataset.Rows),
     97                                                 estimatedClassValues,
     98                                                 weightCalc.GetAllClassDelegate()).ToArray();
     99        } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTrainingSamples)) {
     100          rows = Content.ProblemData.TrainingIndices.Count();
     101          estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
     102          confidences = weightCalc.GetConfidence(solutions,
     103                                                 Content.ProblemData.TrainingIndices,
     104                                                 estimatedClassValues,
     105                                                 weightCalc.GetTrainingClassDelegate()).ToArray();
     106        } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTestSamples)) {
     107          rows = Content.ProblemData.TestIndices.Count();
     108          estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
     109          confidences = weightCalc.GetConfidence(solutions,
     110                                                 Content.ProblemData.TestIndices,
     111                                                 estimatedClassValues,
     112                                                 weightCalc.GetTestClassDelegate()).ToArray();
    117113        }
    118114
     
    130126
    131127          accuracy[i + 1] = accuracyCalc.Accuracy;
    132           covered[i] = 1.0 - (double)notCovered / (double)rows;
     128          if (rows > 0) {
     129            covered[i] = 1.0 - (double)notCovered / (double)rows;
     130          }
    133131          accuracyCalc.Reset();
    134132        }
     
    179177    }
    180178
    181     protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {
    182       if (samplesSelection == SamplesComboBoxAllSamples)
    183         return solutions;
    184       else if (samplesSelection == SamplesComboBoxTrainingSamples)
    185         return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));
    186       else if (samplesSelection == SamplesComboBoxTestSamples)
    187         return solutions.Where(s => s.ProblemData.IsTestSample(curRow));
    188       else
    189         return new List<IClassificationSolution>();
    190     }
    191 
    192179    private double CalculateAreaUnderCurve(Series series) {
    193180      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs

    r8811 r8814  
    112112
    113113      double[] confidences = null;
    114       if (SamplesComboBox.SelectedItem.ToString() == SamplesComboBoxAllSamples) {
    115         confidences = weightCalc.GetConfidence(solutions, indizes, estimatedClassValues).ToArray();
     114
     115      if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxAllSamples)) {
     116        confidences = weightCalc.GetConfidence(solutions,
     117                                               indizes,
     118                                               estimatedClassValues,
     119                                               weightCalc.GetAllClassDelegate()).ToArray();
     120      } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTrainingSamples)) {
     121        confidences = weightCalc.GetConfidence(solutions,
     122                                               indizes,
     123                                               estimatedClassValues,
     124                                               weightCalc.GetTrainingClassDelegate()).ToArray();
     125      } else if (SamplesComboBox.SelectedItem.ToString().Equals(SamplesComboBoxTestSamples)) {
     126        confidences = weightCalc.GetConfidence(solutions,
     127                                               indizes,
     128                                               estimatedClassValues,
     129                                               weightCalc.GetTestClassDelegate()).ToArray();
    116130      }
    117131
     
    125139          correctClassified = target[i].IsAlmost(estimatedClassValues[i]);
    126140          values[i, 3] = correctClassified.ToString();
    127           if (SamplesComboBox.SelectedItem.ToString() == SamplesComboBoxAllSamples) {
    128             curConfidence = confidences[i];
    129           } else {
    130             curConfidence = weightCalc.GetConfidence(GetRelevantSolutions(SamplesComboBox.SelectedItem.ToString(), solutions, row),
    131                                                      indizes[i], estimatedClassValues[i]);
    132           }
     141          curConfidence = confidences[i];
    133142          if (correctClassified) {
    134143            confidence[0] += curConfidence;
     
    167176      matrix.SortableView = true;
    168177      matrixView.Content = matrix;
    169     }
    170 
    171     protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {
    172       if (samplesSelection == SamplesComboBoxAllSamples)
    173         return solutions;
    174       else if (samplesSelection == SamplesComboBoxTrainingSamples)
    175         return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));
    176       else if (samplesSelection == SamplesComboBoxTestSamples)
    177         return solutions.Where(s => s.ProblemData.IsTestSample(curRow));
    178       else
    179         return new List<IClassificationSolution>();
    180     }
    181 
    182     private IEnumerable<int> FindAllIndices(List<double?> list, double value) {
    183       List<int> indices = new List<int>();
    184       for (int i = 0; i < list.Count; i++) {
    185         if (list[i].Equals(value))
    186           indices.Add(i);
    187       }
    188       return indices;
    189178    }
    190179
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj

    r8811 r8814  
    214214    <Compile Include="FeatureCorrelation\AbstractFeatureCorrelationView.Designer.cs">
    215215      <DependentUpon>AbstractFeatureCorrelationView.cs</DependentUpon>
    216     </Compile>
    217     <Compile Include="Classification\ClassificationEnsembleSolutionConfidenceAccuracyDependence.cs">
    218       <SubType>UserControl</SubType>
    219     </Compile>
    220     <Compile Include="Classification\ClassificationEnsembleSolutionConfidenceAccuracyDependence.Designer.cs">
    221       <DependentUpon>ClassificationEnsembleSolutionConfidenceAccuracyDependence.cs</DependentUpon>
    222216    </Compile>
    223217    <Compile Include="Classification\ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs">
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/HeuristicLab.Problems.DataAnalysis-3.4.csproj

    r8811 r8814  
    9494  <ItemGroup>
    9595    <Reference Include="ALGLIB-3.6.0, Version=3.6.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    96       <HintPath>>..\..\..\..\trunk\sources\bin\ALGLIB-3.6.0.dll</HintPath>
     96      <HintPath>&gt;..\..\..\..\trunk\sources\bin\ALGLIB-3.6.0.dll</HintPath>
    9797      <Private>False</Private>
    9898    </Reference>
    9999    <Reference Include="HeuristicLab.ALGLIB-3.6.0, Version=3.6.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    100       <HintPath>>..\..\..\..\trunk\sources\bin\HeuristicLab.ALGLIB-3.6.0.dll</HintPath>
     100      <HintPath>&gt;..\..\..\..\trunk\sources\bin\HeuristicLab.ALGLIB-3.6.0.dll</HintPath>
    101101      <Private>False</Private>
    102102    </Reference>
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs

    r8534 r8814  
    5757    }
    5858
    59     protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
     59    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    6060      Dataset dataset = solutions.First().ProblemData.Dataset;
    61       IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     61      IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
    6262      if (values.Count <= 0)
    6363        return double.NaN;
     
    6666    }
    6767
    68     public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     68    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    6969      Dataset dataset = solutions.First().ProblemData.Dataset;
    70       double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     70      List<int> indicesList = indices.ToList();
     71      var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray());
    7172      double[] confidences = new double[indices.Count()];
    7273      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    7374
    74       for (int i = 0; i < indices.Count(); i++) {
    75         double avg = values.Select(x => x[i]).Average();
    76         confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]);
     75      for (int i = 0; i < indicesList.Count; i++) {
     76        var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]);
     77        if (values.Count() <= 0) {
     78          confidences[i] = double.NaN;
     79        } else {
     80          double avg = values.Average();
     81          confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]);
     82        }
    7783      }
    7884
     
    8490        if (estimatedClassValue.Equals(classValues[i])) {
    8591          //special case: avgerage is higher than value of highest class
    86           if (i == classValues.Length - 1 && avg > estimatedClassValue) {
     92          if (i == classValues.Length - 1 && avg >= estimatedClassValue) {
    8793            return 1;
    8894          }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r8811 r8814  
    2424using HeuristicLab.Common;
    2525using HeuristicLab.Core;
    26 using HeuristicLab.Data;
    2726using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2827
     
    112111    }
    113112
    114     public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     113    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    115114      if (solutions.Count() < 1)
    116115        return double.NaN;
    117116      Dataset dataset = solutions.First().ProblemData.Dataset;
    118117      var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() })
    119                                       .Where(a => a.Values.Equals(estimatedClassValue))
     118                                      .Where(a => handler(a.Solution.ProblemData, index) && a.Values.Equals(estimatedClassValue))
    120119                                      .Select(a => a.Solution);
    121120      return (from sol in correctSolutions
     
    123122    }
    124123
    125     public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     124    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    126125      if (solutions.Count() < 1)
    127126        return Enumerable.Repeat(double.NaN, indices.Count());
    128127
     128      List<int> indicesList = indices.ToList();
     129
    129130      Dataset dataset = solutions.First().ProblemData.Dataset;
    130       Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices).ToArray());
     131      Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
    131132      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    132       double[] confidences = new double[indices.Count()];
     133      double[] confidences = new double[indicesList.Count];
    133134
    134       for (int i = 0; i < indices.Count(); i++) {
     135      for (int i = 0; i < indicesList.Count; i++) {
    135136        var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i]));
    136137        confidences[i] = (from sol in correctSolutions
     138                          where handler(sol.Key.ProblemData, indicesList[i])
    137139                          select weights[sol.Key]).Sum();
    138140      }
     
    142144
    143145    #region Helper
    144     protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
    145       return from i in indizes
    146              select targetValues[i];
    147     }
    148146    protected bool PointInTraining(IClassificationProblemData problemData, int point) {
    149       IntRange trainingPartition = problemData.TrainingPartition;
    150       IntRange testPartition = problemData.TestPartition;
    151       return (trainingPartition.Start <= point && point < trainingPartition.End)
    152         && !(testPartition.Start <= point && point < testPartition.End);
     147      return problemData.IsTrainingSample(point);
    153148    }
    154149    protected bool PointInTest(IClassificationProblemData problemData, int point) {
    155       IntRange testPartition = problemData.TestPartition;
    156       return testPartition.Start <= point && point < testPartition.End;
     150      return problemData.IsTestSample(point);
    157151    }
    158152    protected bool AllPoints(IClassificationProblemData problemData, int point) {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r8811 r8814  
    8282    }
    8383
    84     public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     84    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    8585      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    8686        return double.NaN;
     
    8888      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
    8989
    90       return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue);
     90      return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue, handler);
    9191    }
    9292
    93     protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    94       return base.GetConfidence(solutions, index, estimatedClassValue);
     93    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
     94      return base.GetConfidence(solutions, index, estimatedClassValue, handler);
    9595    }
    9696
    97     public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     97    public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    9898      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
    9999        return Enumerable.Repeat(double.NaN, indices.Count());
     
    101101      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
    102102
    103       return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue);
     103      return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue, handler);
    104104    }
    105105
    106     public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
    107       return base.GetConfidence(solutions, indices, estimatedClassValue);
     106    public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
     107      return base.GetConfidence(solutions, indices, estimatedClassValue, handler);
    108108    }
    109109  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs

    r8297 r8814  
    5252    }
    5353
    54     public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     54    public override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    5555      if (solutions.Count() < 1)
    5656        return double.NaN;
    57       Dataset dataset = solutions.First().ProblemData.Dataset;
    58       int correctEstimated = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First())
     57      var votingSolutions = solutions.Where(s => handler(s.ProblemData, index));
     58      if (votingSolutions.Count() < 1)
     59        return double.NaN;
     60      Dataset dataset = votingSolutions.First().ProblemData.Dataset;
     61      int correctEstimated = votingSolutions.Select(s => s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First())
    5962                                      .Where(x => x.Equals(estimatedClassValue))
    6063                                      .Count();
    61       return ((double)correctEstimated / (double)solutions.Count() - 0.5) * 2;
     64      return ((double)correctEstimated / (double)votingSolutions.Count() - 0.5) * 2;
    6265    }
    6366
    64     public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     67    public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    6568      if (solutions.Count() < 1)
    6669        return Enumerable.Repeat(double.NaN, indices.Count());
     70
     71      List<int> indicesList = indices.ToList();
    6772      Dataset dataset = solutions.First().ProblemData.Dataset;
    68       var estimationsPerSolution = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, indices).ToArray()).ToArray();
     73      var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indicesList).ToArray());
    6974      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    7075      double correctEstimated;
    7176      double[] confidences = new double[indices.Count()];
    7277
    73       for (int i = 0; i < indices.Count(); i++) {
    74         correctEstimated = estimationsPerSolution.Where(x => DoubleExtensions.IsAlmost(x[i], estimatedClassValueArr[i])).Count();
    75         confidences[i] = (correctEstimated / (double)solutions.Count() - 0.5) * 2;
     78      for (int i = 0; i < indicesList.Count; i++) {
     79        var votingSolutions = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i]));
     80        correctEstimated = votingSolutions.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i])).Count();
     81        confidences[i] = (correctEstimated / (double)votingSolutions.Count() - 0.5) * 2;
    7682      }
    7783
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r8534 r8814  
    5757    }
    5858
    59     protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    60 
     59    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler) {
    6160      Dataset dataset = solutions.First().ProblemData.Dataset;
    62       IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     61      IList<double> values = solutions.Where(s => handler(s.ProblemData, index)).Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
    6362      if (values.Count <= 0)
    6463        return double.NaN;
     
    6766    }
    6867
    69     public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     68    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler) {
    7069
    7170      Dataset dataset = solutions.First().ProblemData.Dataset;
    72       double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     71      List<int> indicesList = indices.ToList();
     72      var solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedValues(dataset, indicesList).ToArray());
    7373      double[] confidences = new double[indices.Count()];
    7474      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
    7575
    76       for (int i = 0; i < indices.Count(); i++) {
    77         double avg = values.Select(x => x[i]).Average();
    78         confidences[i] = GetMedianConfidence(avg, estimatedClassValueArr[i]);
     76      for (int i = 0; i < indicesList.Count; i++) {
     77        var values = solValues.Where(x => handler(x.Key.ProblemData, indicesList[i])).Select(x => x.Value[i]).ToList();
     78        if (values.Count <= 0) {
     79          confidences[i] = double.NaN;
     80        } else {
     81          double median = GetMedian(values);
     82          confidences[i] = GetMedianConfidence(median, estimatedClassValueArr[i]);
     83        }
    7984      }
    8085
     
    8590      for (int i = 0; i < classValues.Length; i++) {
    8691        if (estimatedClassValue.Equals(classValues[i])) {
    87           //special case: avgerage is higher than value of highest class
    88           if (i == classValues.Length - 1 && median > estimatedClassValue) {
     92          //special case: median is higher than value of highest class
     93          if (i == classValues.Length - 1 && median >= estimatedClassValue) {
    8994            return 1;
    9095          }
    91           //special case: average is lower than value of lowest class
     96          //special case: median is lower than value of lowest class
    9297          if (i == 0 && median < estimatedClassValue) {
    9398            return 1;
    9499          }
    95           //special case: average is not between threshold of estimated class value
     100          //special case: median is not between threshold of estimated class value
    96101          if ((i < classValues.Length - 1 && median >= threshold[i + 1]) || median <= threshold[i]) {
    97102            return 0;
    98103          }
    99104
    100           double thresholdToClassDistance, thresholdToAverageValueDistance;
     105          double thresholdToClassDistance, thresholdToMedianValueDistance;
    101106          if (median >= classValues[i]) {
    102107            thresholdToClassDistance = threshold[i + 1] - classValues[i];
    103             thresholdToAverageValueDistance = threshold[i + 1] - median;
     108            thresholdToMedianValueDistance = threshold[i + 1] - median;
    104109          } else {
    105110            thresholdToClassDistance = classValues[i] - threshold[i];
    106             thresholdToAverageValueDistance = median - threshold[i];
     111            thresholdToMedianValueDistance = median - threshold[i];
    107112          }
    108           return (1 / thresholdToClassDistance) * thresholdToAverageValueDistance;
     113          return (1 / thresholdToClassDistance) * thresholdToMedianValueDistance;
    109114        }
    110115      }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs

    r8811 r8814  
    2929    void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions);
    3030    IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler);
    31     double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue);
    32     IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue);
     31    double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue, CheckPoint handler);
     32    IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue, CheckPoint handler);
    3333
    3434    CheckPoint GetTestClassDelegate();
Note: See TracChangeset for help on using the changeset viewer.