Free cookie consent management tool by TermsFeed Policy Generator

Changeset 7562


Ignore:
Timestamp:
03/06/12 15:08:13 (13 years ago)
Author:
sforsten
Message:

#1776:

  • bug fix in NeighbourhoodWeightCalculator
  • added GetConfidence method to IClassificationEnsembleSolutionWeightCalculator
  • adjusted the confidence column in ClassificationEnsembleSolutionEstimatedClassValuesView
Location:
branches/ClassificationEnsembleVoting
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionEstimatedClassValuesView.cs

    r7531 r7562  
    2828using HeuristicLab.MainForm;
    2929using HeuristicLab.MainForm.WindowsForms;
     30using HeuristicLab.Problems.DataAnalysis.Interfaces.Classification;
    3031
    3132namespace HeuristicLab.Problems.DataAnalysis.Views {
     
    9697      }
    9798
     99      IEnumerable<IClassificationSolution> solutions = Content.ClassificationSolutions.CheckedItems;
    98100      int classValuesCount = Content.ProblemData.ClassValues.Count;
    99       int solutionsCount = Content.ClassificationSolutions.Count();
     101      int solutionsCount = solutions.Count();
    100102      string[,] values = new string[indizes.Length, 5 + classValuesCount + solutionsCount];
    101103      double[] target = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();
    102104      List<List<double?>> estimatedValuesVector = GetEstimatedValues(SamplesComboBox.SelectedItem.ToString(), indizes,
    103                                                             Content.ClassificationSolutions);
    104       //List<double> weights = Content.Weights.ToList();
    105       //double weightSum = weights.Sum();
     105                                                            solutions);
     106
     107      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
    106108
    107109      for (int i = 0; i < indizes.Length; i++) {
     
    113115          values[i, 2] = estimatedClassValues[i].ToString();
    114116          values[i, 3] = (target[i].IsAlmost(estimatedClassValues[i])).ToString();
    115 
    116           //currently disabled for test purpose
    117 
    118           //IEnumerable<int> indices = FindAllIndices(estimatedValuesVector[i], estimatedClassValues[i]);
    119           //double confidence = 0.0;
    120           //foreach (var index in indices) {
    121           //  confidence += weights[index];
    122           //}
    123           //values[i, 4] = (confidence / weightSum).ToString();
    124           //var estimationCount = groups.Where(g => g.Key != null).Select(g => g.Count).Sum();
    125           //values[i, 4] =
    126           //  (((double)groups.Where(g => g.Key == estimatedClassValues[i]).Single().Count) / estimationCount).ToString();
    127           values[i, 4] = "1.0";
     117          values[i, 4] = weightCalc.GetConfidence(solutions, indizes[i], estimatedClassValues[i]).ToString();
    128118
    129119          var groups =
     
    145135      List<string> columnNames = new List<string>() { "Id", TargetClassValuesColumnName, EstimatedClassValuesColumnName, CorrectClassificationColumnName, ConfidenceColumnName };
    146136      columnNames.AddRange(Content.ProblemData.ClassNames);
    147       columnNames.AddRange(Content.Model.Models.Select(m => m.Name));
     137      columnNames.AddRange(Content.ClassificationSolutions.CheckedItems.Select(s => s.Model.Name));//.Model.Models.Select(m => m.Name));
    148138      matrix.ColumnNames = columnNames;
    149139      matrix.SortableView = true;
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r7549 r7562  
    5151    }
    5252
    53     //[Storable]
    54     //private Dictionary<IClassificationModel, IntRange> trainingPartitions;
    55     //[Storable]
    56     //private Dictionary<IClassificationModel, IntRange> testPartitions;
    57 
    5853    private IClassificationEnsembleSolutionWeightCalculator weightCalculator;
    5954
     
    6661        }
    6762      }
     63      get { return weightCalculator; }
    6864    }
    6965
     
    7773      foreach (var model in Model.Models) {
    7874        IClassificationProblemData problemData = (IClassificationProblemData)ProblemData.Clone();
    79         //problemData.TrainingPartition.Start = trainingPartitions[model].Start;
    80         //problemData.TrainingPartition.End = trainingPartitions[model].End;
    81         //problemData.TestPartition.Start = testPartitions[model].Start;
    82         //problemData.TestPartition.End = testPartitions[model].End;
    83 
    8475        classificationSolutions.Add(model.CreateClassificationSolution(problemData));
    8576      }
     
    8980    private ClassificationEnsembleSolution(ClassificationEnsembleSolution original, Cloner cloner)
    9081      : base(original, cloner) {
    91       //trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    92       //testPartitions = new Dictionary<IClassificationModel, IntRange>();
    93       //foreach (var pair in original.trainingPartitions) {
    94       //  trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    95       //}
    96       //foreach (var pair in original.testPartitions) {
    97       //  testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
    98       //}
    99 
    10082      classificationSolutions = cloner.Clone(original.classificationSolutions);
    10183      RegisterClassificationSolutionsEventHandler();
     
    10486    public ClassificationEnsembleSolution()
    10587      : base(new ClassificationEnsembleModel(), ClassificationEnsembleProblemData.EmptyProblemData) {
    106       //trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    107       //testPartitions = new Dictionary<IClassificationModel, IntRange>();
    10888      classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
    10989
     
    121101    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    122102      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
    123       //this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
    124       //this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
    125103      this.classificationSolutions = new CheckedItemCollection<IClassificationSolution>();
    126104
     
    217195          solution.ProblemData = problemData;
    218196      }
    219       //foreach (var trainingPartition in trainingPartitions.Values) {
    220       //  trainingPartition.Start = ProblemData.TrainingPartition.Start;
    221       //  trainingPartition.End = ProblemData.TrainingPartition.End;
    222       //}
    223       //foreach (var testPartition in testPartitions.Values) {
    224       //  testPartition.Start = ProblemData.TestPartition.Start;
    225       //  testPartition.End = ProblemData.TestPartition.End;
    226       //}
    227 
    228197      base.OnProblemDataChanged();
    229198    }
     
    256225      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    257226      Model.Add(solution.Model);
    258       //trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    259       //testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    260227    }
    261228
     
    263230      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    264231      Model.Remove(solution.Model);
    265       //trainingPartitions.Remove(solution.Model);
    266       //testPartitions.Remove(solution.Model);
    267232    }
    268233  }
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/ClassificationWeightCalculator.cs

    r7559 r7562  
    113113    }
    114114
     115    public virtual double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     116      if (solutions.Count() < 1)
     117        return double.NaN;
     118      Dataset dataset = solutions.First().ProblemData.Dataset;
     119      var correctSolutions = solutions.Select(s => new { Solution = s, Values = s.Model.GetEstimatedClassValues(dataset, Enumerable.Repeat(index, 1)).First() })
     120                                      .Where(a => a.Values.Equals(estimatedClassValue))
     121                                      .Select(a => a.Solution);
     122      return (from sol in correctSolutions
     123              select weights[sol]).Sum();
     124    }
     125
    115126    #region Helper
    116127    protected IEnumerable<double> GetValues(IList<double> targetValues, IEnumerable<int> indizes) {
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/DiscriminantClassificationWeightCalculator.cs

    r7559 r7562  
    8282      }
    8383    }
     84
     85    public sealed override double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue) {
     86      if (solutions.Count() < 1 || !solutions.All(x => x is IDiscriminantFunctionClassificationSolution))
     87        return double.NaN;
     88
     89      IEnumerable<IDiscriminantFunctionClassificationSolution> discriminantSolutions = solutions.Cast<IDiscriminantFunctionClassificationSolution>();
     90
     91      return GetDiscriminantConfidence(discriminantSolutions, index, estimatedClassValue);
     92    }
     93
     94    protected virtual double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
     95      return base.GetConfidence(solutions, index, estimatedClassValue);
     96    }
    8497  }
    8598}
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/MedianThresholdCalculator.cs

    r7549 r7562  
    2020#endregion
    2121
     22using System.Collections;
    2223using System.Collections.Generic;
    2324using System.Linq;
     
    8485      }
    8586      AccuracyMaximizationThresholdCalculator.CalculateThresholds(solutionProblemData[0], median, targetValues, out classValues, out threshold);
    86       return median;
     87      return Enumerable.Repeat<double>(1, discriminantSolutions.Count());
    8788    }
    8889
    8990    protected override double DiscriminantAggregateEstimatedClassValues(IDictionary<IClassificationSolution, double> estimatedClassValues, IDictionary<IClassificationSolution, double> estimatedValues) {
    90       double classValue = classValues.First();
    9191      IList<double> values = estimatedValues.Select(x => x.Value).ToList();
    9292      if (values.Count <= 0)
    9393        return double.NaN;
    9494      double median = GetMedian(values);
     95      return GetClassValueToMedian(median);
     96    }
     97    private double GetClassValueToMedian(double median) {
     98      double classValue = classValues.First();
    9599      for (int i = 0; i < classValues.Count(); i++) {
    96100        if (median > threshold[i])
     
    100104      }
    101105      return classValue;
     106    }
     107
     108    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
     109      // only works with binary classification
     110      if (!classValues.Count().Equals(2))
     111        return double.NaN;
     112      Dataset dataset = solutions.First().ProblemData.Dataset;
     113      IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     114      if (values.Count <= 0)
     115        return double.NaN;
     116      double median = GetMedian(values);
     117      if (estimatedClassValue.Equals(classValues[0])) {
     118        if (median < estimatedClassValue)
     119          return 1;
     120        else if (median >= threshold[1])
     121          return 0;
     122        else {
     123          double distance = threshold[1] - classValues[0];
     124          return (1 / distance) * (median - classValues[0]);
     125        }
     126      } else if (estimatedClassValue.Equals(classValues[1])) {
     127        if (median > estimatedClassValue)
     128          return 1;
     129        else if (median <= threshold[1])
     130          return 0;
     131        else {
     132          double distance = classValues[1] - threshold[1];
     133          return (1 / distance) * (classValues[1] - median);
     134        }
     135      } else
     136        return double.NaN;
    102137    }
    103138
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/NeighbourhoodWeightCalculator.cs

    r7549 r7562  
    5858      foreach (var solution in discriminantSolutions) {
    5959        estimatedValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
    60         estimatedClassValues.Add(solution.Model.GetEstimatedValues(dataSet, rows).ToList());
     60        estimatedClassValues.Add(solution.Model.GetEstimatedClassValues(dataSet, rows).ToList());
    6161      }
    6262
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IClassificationEnsembleSolutionWeightCalculator.cs

    r7549 r7562  
    2525namespace HeuristicLab.Problems.DataAnalysis.Interfaces.Classification {
    2626  public delegate bool CheckPoint(IClassificationProblemData problemData, int point);
     27
    2728  public interface IClassificationEnsembleSolutionWeightCalculator : INamedItem {
    2829    void CalculateNormalizedWeights(IEnumerable<IClassificationSolution> classificationSolutions);
    2930    IEnumerable<double> AggregateEstimatedClassValues(IEnumerable<IClassificationSolution> solutions, Dataset dataset, IEnumerable<int> rows, CheckPoint handler);
     31    double GetConfidence(IEnumerable<IClassificationSolution> solutions, int index, double estimatedClassValue);
     32
    3033    CheckPoint GetTestClassDelegate();
    3134    CheckPoint GetTrainingClassDelegate();
Note: See TracChangeset for help on using the changeset viewer.