Free cookie consent management tool by TermsFeed Policy Generator

Changeset 8297


Ignore:
Timestamp:
07/17/12 15:30:04 (11 years ago)
Author:
sforsten
Message:

#1776:

  • Corrected namespace of IClassificationEnsembleSolutionWeightCalculator interface
  • Corrected calculation of confidence for test and training samples in ClassificationEnsembleSolutionEstimatedClassValuesView
  • Added an overload of GetConfidence to IClassificationEnsembleSolutionWeightCalculator that calculates the confidence for several points at once (additional methods for training and test confidence might improve the performance considerably)
  • Added ClassificationEnsembleSolutionConfidenceAccuracyDependence to show how the accuracy behaves if only samples with high confidence are classified
Location:
branches/ClassificationEnsembleVoting
Files:
2 added
10 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs

    r8101 r8297  
    2828using HeuristicLab.MainForm;
    2929using HeuristicLab.MainForm.WindowsForms;
    30 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     30using HeuristicLab.Problems.DataAnalysis.Interfaces;
    3131
    3232namespace HeuristicLab.Problems.DataAnalysis.Views {
     
    111111      double curConfidence;
    112112
     113      double[] confidences = null;
     114      if (SamplesComboBox.SelectedItem.ToString() == SamplesComboBoxAllSamples) {
     115        confidences = weightCalc.GetConfidence(solutions, indizes, estimatedClassValues).ToArray();
     116      }
     117
    113118      for (int i = 0; i < indizes.Length; i++) {
    114119        int row = indizes[i];
     
    120125          correctClassified = target[i].IsAlmost(estimatedClassValues[i]);
    121126          values[i, 3] = correctClassified.ToString();
    122           curConfidence = weightCalc.GetConfidence(solutions, indizes[i], estimatedClassValues[i]);
     127          if (SamplesComboBox.SelectedItem.ToString() == SamplesComboBoxAllSamples) {
     128            curConfidence = confidences[i];
     129          } else {
     130            curConfidence = weightCalc.GetConfidence(GetRelevantSolutions(SamplesComboBox.SelectedItem.ToString(), solutions, row),
     131                                                     indizes[i], estimatedClassValues[i]);
     132          }
    123133          if (correctClassified) {
    124134            confidence[0] += curConfidence;
     
    156166      matrix.SortableView = true;
    157167      matrixView.Content = matrix;
     168    }
     169
     170    protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {
     171      if (samplesSelection == SamplesComboBoxAllSamples)
     172        return solutions;
     173      else if (samplesSelection == SamplesComboBoxTrainingSamples)
     174        return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));
     175      else if (samplesSelection == SamplesComboBoxTestSamples)
     176        return solutions.Where(s => s.ProblemData.IsTestSample(curRow));
     177      else
     178        return new List<IClassificationSolution>();
    158179    }
    159180
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj

    r7866 r8297  
    154154      <DependentUpon>ClassificationEnsembleSolutionModelView.cs</DependentUpon>
    155155    </Compile>
     156    <Compile Include="Classification\ClassificationEnsembleSolutionConfidenceAccuracyDependence.cs">
     157      <SubType>UserControl</SubType>
     158    </Compile>
     159    <Compile Include="Classification\ClassificationEnsembleSolutionConfidenceAccuracyDependence.Designer.cs">
     160      <DependentUpon>ClassificationEnsembleSolutionConfidenceAccuracyDependence.cs</DependentUpon>
     161    </Compile>
    156162    <Compile Include="DataAnalysisSolutionEvaluationView.cs">
    157163      <SubType>UserControl</SubType>
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Solution Views/ClassificationEnsembleSolutionView.cs

    r7866 r8297  
    2626using HeuristicLab.MainForm;
    2727using HeuristicLab.PluginInfrastructure;
    28 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     28using HeuristicLab.Problems.DataAnalysis.Interfaces;
    2929
    3030namespace HeuristicLab.Problems.DataAnalysis.Views {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r8177 r8297  
    2828using HeuristicLab.Data;
    2929using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    30 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     30using HeuristicLab.Problems.DataAnalysis.Interfaces;
    3131
    3232namespace HeuristicLab.Problems.DataAnalysis {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs

    r8101 r8297  
    6666        return double.NaN;
    6767      double avg = values.Average();
     68      return GetAverageConfidence(avg, estimatedClassValue);
     69    }
     70
     71    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     72      if (!classValues.Count().Equals(2))
     73        return Enumerable.Repeat(double.NaN, indices.Count());
     74
     75      Dataset dataset = solutions.First().ProblemData.Dataset;
     76      double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     77      double[] confidences = new double[indices.Count()];
     78      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
     79
     80      for (int i = 0; i < indices.Count(); i++) {
     81        double avg = values.Select(x => x[i]).Average();
     82        confidences[i] = GetAverageConfidence(avg, estimatedClassValueArr[i]);
     83      }
     84
     85      return confidences;
     86    }
     87
     88    protected double GetAverageConfidence(double avg, double estimatedClassValue) {
    6889      if (estimatedClassValue.Equals(classValues[0])) {
    6990        if (avg < estimatedClassValue)
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r7562 r8297  
    2626using HeuristicLab.Data;
    2727using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    28 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     28using HeuristicLab.Problems.DataAnalysis.Interfaces;
    2929
    3030namespace HeuristicLab.Problems.DataAnalysis {
     
    124124    }
    125125
     126    public virtual IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     127      if (solutions.Count() < 1)
     128        return Enumerable.Repeat(double.NaN, indices.Count());
     129
     130      Dataset dataset = solutions.First().ProblemData.Dataset;
     131      Dictionary<IClassificationSolution, double[]> solValues = solutions.ToDictionary(x => x, x => x.Model.GetEstimatedClassValues(dataset, indices).ToArray());
     132      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
     133      double[] confidences = new double[indices.Count()];
     134
     135      for (int i = 0; i < indices.Count(); i++) {
     136        var correctSolutions = solValues.Where(x => DoubleExtensions.IsAlmost(x.Value[i], estimatedClassValueArr[i]));
     137        confidences[i] = (from sol in correctSolutions
     138                          select weights[sol.Key]).Sum();
     139      }
     140
     141      return confidences;
     142    }
     143
    126144    #region Helper
    127145    protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r8177 r8297  
    2424using HeuristicLab.Common;
    2525using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    26 using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
     26using HeuristicLab.Problems.DataAnalysis.Interfaces;
    2727
    2828namespace HeuristicLab.Problems.DataAnalysis {
     
    9595      return base.GetConfidence(solutions, index, estimatedClassValue);
    9696    }
     97
     98    public sealed override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     99      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
     100        return Enumerable.Repeat(double.NaN, indices.Count());
     101
     102      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
     103
     104      return GetDiscriminantConfidence(discriminantSolutions, indices, estimatedClassValue);
     105    }
     106
     107    public virtual IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     108      return base.GetConfidence(solutions, indices, estimatedClassValue);
     109    }
    97110  }
    98111}
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MajorityVoteWeightCalculator.cs

    r7866 r8297  
    6161      return ((double)correctEstimated / (double)solutions.Count() - 0.5) * 2;
    6262    }
     63
     64    public override IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     65      if (solutions.Count() < 1)
     66        return Enumerable.Repeat(double.NaN, indices.Count());
     67      Dataset dataset = solutions.First().ProblemData.Dataset;
     68      var estimationsPerSolution = solutions.Select(s => s.Model.GetEstimatedClassValues(dataset, indices).ToArray()).ToArray();
     69      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
     70      double correctEstimated;
     71      double[] confidences = new double[indices.Count()];
     72
     73      for (int i = 0; i < indices.Count(); i++) {
     74        correctEstimated = estimationsPerSolution.Where(x => DoubleExtensions.IsAlmost(x[i], estimatedClassValueArr[i])).Count();
     75        confidences[i] = (correctEstimated / (double)solutions.Count() - 0.5) * 2;
     76      }
     77
     78      return confidences;
     79    }
    6380  }
    6481}
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r8101 r8297  
    6666        return double.NaN;
    6767      double median = GetMedian(values);
     68      return GetMedianConfidence(median, estimatedClassValue);
     69    }
     70
     71    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
     72      if (!classValues.Count().Equals(2))
     73        return Enumerable.Repeat(double.NaN, indices.Count());
     74
     75      Dataset dataset = solutions.First().ProblemData.Dataset;
     76      double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     77      double[] confidences = new double[indices.Count()];
     78      double[] estimatedClassValueArr = estimatedClassValue.ToArray();
     79
     80      for (int i = 0; i < indices.Count(); i++) {
     81        double avg = values.Select(x => x[i]).Average();
     82        confidences[i] = GetMedianConfidence(avg, estimatedClassValueArr[i]);
     83      }
     84
     85      return confidences;
     86    }
     87
     88    protected double GetMedianConfidence(double median, double estimatedClassValue) {
    6889      if (estimatedClassValue.Equals(classValues[0])) {
    6990        if (median < estimatedClassValue)
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs

    r7562 r8297  
    2323using HeuristicLab.Core;
    2424
    25 namespace HeuristicLab.Problems.DataAnalysis.Interfaces.Classification {
     25namespace HeuristicLab.Problems.DataAnalysis.Interfaces {
    2626  public delegate bool CheckPoint(IClassificationProblemData problemData, int point);
    2727
     
    3030    IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler);
    3131    double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue);
     32    IEnumerable<double> GetConfidence(IEnumerable<IClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue);
    3233
    3334    CheckPoint GetTestClassDelegate();
Note: See TracChangeset for help on using the changeset viewer.