Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/28/12 16:13:02 (12 years ago)
Author:
sforsten
Message:

#1776:

  • merged r8508:8533 from trunk into branch
  • AverageThresholdCalculator and MedianThresholdCalculator can now handle multi-class classification
  • changed combo boxes in ClassificationEnsembleSolutionView to drop-down lists
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis
Files:
11 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis

  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleModel.cs

    r7259 r8534  
    9595
    9696    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
    97       return new ClassificationEnsembleSolution(models, problemData);
     97      return new ClassificationEnsembleSolution(models, new ClassificationEnsembleProblemData(problemData));
    9898    }
    9999    #endregion
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r8508 r8534  
    114114    }
    115115
     116    public ClassificationEnsembleSolution(IClassificationProblemData problemData) :
     117      this(Enumerable.Empty<IClassificationModel>(), problemData) { }
     118
    116119    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData)
    117120      : this(models, problemData,
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationProblemData.cs

    r8508 r8534  
    277277
    278278    public ClassificationProblemData() : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) { }
     279
     280    public ClassificationProblemData(IClassificationProblemData classificationProblemData)
     281      : this(classificationProblemData.Dataset, classificationProblemData.AllowedInputVariables, classificationProblemData.TargetVariable) {
     282      TrainingPartition.Start = classificationProblemData.TrainingPartition.Start;
     283      TrainingPartition.End = classificationProblemData.TrainingPartition.End;
     284      TestPartition.Start = classificationProblemData.TestPartition.Start;
     285      TestPartition.End = classificationProblemData.TestPartition.End;
     286    }
     287
    279288    public ClassificationProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables, string targetVariable)
    280289      : base(dataset, allowedInputVariables) {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r8508 r8534  
    5151      valueEvaluationCache = new Dictionary<int, double>();
    5252      classValueEvaluationCache = new Dictionary<int, double>();
    53 
    54       SetAccuracyMaximizingThresholds();
    5553    }
    5654
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolutionBase.cs

    r8508 r8534  
    9696    protected override void OnModelChanged() {
    9797      DeregisterEventHandler();
    98       SetAccuracyMaximizingThresholds();
    9998      RegisterEventHandler();
    10099      base.OnModelChanged();
     
    137136    }
    138137
    139     public void SetAccuracyMaximizingThresholds() {
    140       double[] classValues;
    141       double[] thresholds;
    142       var targetClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
    143       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    144 
    145       Model.SetThresholdsAndClassValues(thresholds, classValues);
    146     }
    147 
    148     public void SetClassDistibutionCutPointThresholds() {
    149       double[] classValues;
    150       double[] thresholds;
    151       var targetClassValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
    152       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    153 
    154       Model.SetThresholdsAndClassValues(thresholds, classValues);
    155     }
    156 
    157138    protected virtual void OnModelThresholdsChanged(EventArgs e) {
    158139      CalculateResults();
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs

    r8297 r8534  
    5858
    5959    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    60       // only works with binary classification
    61       if (!classValues.Count().Equals(2))
    62         return double.NaN;
    6360      Dataset dataset = solutions.First().ProblemData.Dataset;
    6461      IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     
    7067
    7168    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
    72       if (!classValues.Count().Equals(2))
    73         return Enumerable.Repeat(double.NaN, indices.Count());
    74 
    7569      Dataset dataset = solutions.First().ProblemData.Dataset;
    7670      double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     
    8781
    8882    protected double GetAverageConfidence(double avg, double estimatedClassValue) {
    89       if (estimatedClassValue.Equals(classValues[0])) {
    90         if (avg < estimatedClassValue)
    91           return 1;
    92         else if (avg >= threshold[1])
    93           return 0;
    94         else {
    95           double distance = threshold[1] - classValues[0];
    96           return (1 / distance) * (threshold[1] - avg);
     83      for (int i = 0; i < classValues.Length; i++) {
     84        if (estimatedClassValue.Equals(classValues[i])) {
     85          //special case: avgerage is higher than value of highest class
     86          if (i == classValues.Length - 1 && avg > estimatedClassValue) {
     87            return 1;
     88          }
     89          //special case: average is lower than value of lowest class
     90          if (i == 0 && avg < estimatedClassValue) {
     91            return 1;
     92          }
     93          //special case: average is not between threshold of estimated class value
     94          if ((i < classValues.Length - 1 && avg >= threshold[i + 1]) || avg <= threshold[i]) {
     95            return 0;
     96          }
     97
     98          double thresholdToClassDistance, thresholdToAverageValueDistance;
     99          if (avg >= classValues[i]) {
     100            thresholdToClassDistance = threshold[i + 1] - classValues[i];
     101            thresholdToAverageValueDistance = threshold[i + 1] - avg;
     102          } else {
     103            thresholdToClassDistance = classValues[i] - threshold[i];
     104            thresholdToAverageValueDistance = avg - threshold[i];
     105          }
     106          return (1 / thresholdToClassDistance) * thresholdToAverageValueDistance;
    97107        }
    98       } else if (estimatedClassValue.Equals(classValues[1])) {
    99         if (avg > estimatedClassValue)
    100           return 1;
    101         else if (avg <= threshold[1])
    102           return 0;
    103         else {
    104           double distance = classValues[1] - threshold[1];
    105           return (1 / distance) * (avg - threshold[1]);
    106         }
    107       } else
    108         return double.NaN;
     108      }
     109      return double.NaN;
    109110    }
    110111  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r8297 r8534  
    5858
    5959    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    60       // only works with binary classification
    61       if (!classValues.Count().Equals(2))
    62         return double.NaN;
     60
    6361      Dataset dataset = solutions.First().ProblemData.Dataset;
    6462      IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     
    7068
    7169    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
    72       if (!classValues.Count().Equals(2))
    73         return Enumerable.Repeat(double.NaN, indices.Count());
    7470
    7571      Dataset dataset = solutions.First().ProblemData.Dataset;
     
    8783
    8884    protected double GetMedianConfidence(double median, double estimatedClassValue) {
    89       if (estimatedClassValue.Equals(classValues[0])) {
    90         if (median < estimatedClassValue)
    91           return 1;
    92         else if (median >= threshold[1])
    93           return 0;
    94         else {
    95           double distance = threshold[1] - classValues[0];
    96           return (1 / distance) * (threshold[1] - median);
     85      for (int i = 0; i < classValues.Length; i++) {
     86        if (estimatedClassValue.Equals(classValues[i])) {
     87          //special case: avgerage is higher than value of highest class
     88          if (i == classValues.Length - 1 && median > estimatedClassValue) {
     89            return 1;
     90          }
     91          //special case: average is lower than value of lowest class
     92          if (i == 0 && median < estimatedClassValue) {
     93            return 1;
     94          }
     95          //special case: average is not between threshold of estimated class value
     96          if ((i < classValues.Length - 1 && median >= threshold[i + 1]) || median <= threshold[i]) {
     97            return 0;
     98          }
     99
     100          double thresholdToClassDistance, thresholdToAverageValueDistance;
     101          if (median >= classValues[i]) {
     102            thresholdToClassDistance = threshold[i + 1] - classValues[i];
     103            thresholdToAverageValueDistance = threshold[i + 1] - median;
     104          } else {
     105            thresholdToClassDistance = classValues[i] - threshold[i];
     106            thresholdToAverageValueDistance = median - threshold[i];
     107          }
     108          return (1 / thresholdToClassDistance) * thresholdToAverageValueDistance;
    97109        }
    98       } else if (estimatedClassValue.Equals(classValues[1])) {
    99         if (median > estimatedClassValue)
    100           return 1;
    101         else if (median <= threshold[1])
    102           return 0;
    103         else {
    104           double distance = classValues[1] - threshold[1];
    105           return (1 / distance) * (median - threshold[1]);
    106         }
    107       } else
    108         return double.NaN;
     110      }
     111      return double.NaN;
    109112    }
    110113
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/ConstantRegressionModel.cs

    r7259 r8534  
    5555
    5656    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    57       return new ConstantRegressionSolution(this, problemData);
     57      return new ConstantRegressionSolution(this, new RegressionProblemData(problemData));
    5858    }
    5959  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleModel.cs

    r7259 r8534  
    102102
    103103    public RegressionEnsembleSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    104       return new RegressionEnsembleSolution(this.Models, problemData);
     104      return new RegressionEnsembleSolution(this.Models, new RegressionEnsembleProblemData(problemData));
    105105    }
    106106    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionProblemData.cs

    r8508 r8534  
    121121      : this(defaultDataset, defaultAllowedInputVariables, defaultTargetVariable) {
    122122    }
     123    public RegressionProblemData(IRegressionProblemData regressionProblemData)
     124      : this(regressionProblemData.Dataset, regressionProblemData.AllowedInputVariables, regressionProblemData.TargetVariable) {
     125      TrainingPartition.Start = regressionProblemData.TrainingPartition.Start;
     126      TrainingPartition.End = regressionProblemData.TrainingPartition.End;
     127      TestPartition.Start = regressionProblemData.TestPartition.Start;
     128      TestPartition.End = regressionProblemData.TestPartition.End;
     129    }
    123130
    124131    public RegressionProblemData(Dataset dataset, IEnumerable<string> allowedInputVariables, string targetVariable)
Note: See TracChangeset for help on using the changeset viewer.