Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
12/05/12 17:03:53 (11 years ago)
Author:
abeham
Message:

#1943: modified nearest neighbor model to include number of samples at each estimated value

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs

    r8979 r9002  
    3939    private int k;
    4040    [Storable]
    41     private List<KeyValuePair<double, double>> trainedTargetPair;
     41    private List<KeyValuePair<double, Dictionary<double, int>>> trainedTargetPair;
    4242    [Storable]
    4343    private ClassFrequencyComparer frequencyComparer;
     
    4848      : base(original, cloner) {
    4949      k = original.k;
    50       trainedTargetPair = new List<KeyValuePair<double, double>>(original.trainedTargetPair);
     50      trainedTargetPair = original.trainedTargetPair.Select(x => new KeyValuePair<double, Dictionary<double, int>>(x.Key, new Dictionary<double, int>(x.Value))).ToList();
    5151      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
    5252    }
     
    5454      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
    5555      this.k = k;
    56       this.trainedTargetPair = new List<KeyValuePair<double, double>>();
     56      trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();
    5757      frequencyComparer = new ClassFrequencyComparer();
    5858    }
     
    6060    public override IDeepCloneable Clone(Cloner cloner) {
    6161      return new SymbolicNearestNeighbourClassificationModel(this, cloner);
    62     }
    63 
    64     [StorableHook(HookType.AfterDeserialization)]
    65     private void AfterDeserialization() {
    66       if (frequencyComparer == null) {
    67         var dict = trainedTargetPair
    68           .GroupBy(x => x.Value)
    69           .ToDictionary(x => x.Key, y => y.Count());
    70         frequencyComparer = new ClassFrequencyComparer(dict);
    71       }
    7262    }
    7363
     
    7868      foreach (var ev in estimatedValues) {
    7969        // find the index of the training-point to which distance is shortest
    80         var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, double>(ev, double.NaN), new KeyValuePairKeyComparer());
     70        var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, Dictionary<double, int>>(ev, null), new KeyValuePairKeyComparer<Dictionary<double, int>>());
    8171        if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
    8272        var lower = upper - 1;
     
    8474        // continue to the left and right of this index and look for the nearest neighbors
    8575        for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) {
    86           if (upper >= trainedTargetPair.Count || (lower > 0 && ev - trainedTargetPair[lower].Key < trainedTargetPair[upper].Key - ev)) {
     76          bool lowerIsCloser = upper >= trainedTargetPair.Count || (lower >= 0 && ev - trainedTargetPair[lower].Key <= trainedTargetPair[upper].Key - ev);
     77          bool upperIsCloser = lower < 0 || (upper < trainedTargetPair.Count && ev - trainedTargetPair[lower].Key >= trainedTargetPair[upper].Key - ev);
     78          if (lowerIsCloser) {
    8779            // the nearer neighbor is to the left
    88             var lowerClass = trainedTargetPair[lower].Value;
    89             if (!neighborClasses.ContainsKey(lowerClass)) neighborClasses[lowerClass] = 1;
    90             else neighborClasses[lowerClass]++;
     80            var lowerClassSamples = trainedTargetPair[lower].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value
     81            var lowerClasses = trainedTargetPair[lower].Value;
     82            foreach (var lowerClass in lowerClasses) {
     83              if (!neighborClasses.ContainsKey(lowerClass.Key)) neighborClasses[lowerClass.Key] = lowerClass.Value;
     84              else neighborClasses[lowerClass.Key] += lowerClass.Value;
     85            }
    9186            lower--;
    92           } else {
     87            i += (lowerClassSamples - 1);
     88          }
     89          // they could, in very rare cases, be equally far apart
     90          if (upperIsCloser) {
    9391            // the nearer neighbor is to the right
    94             var upperClass = trainedTargetPair[upper].Value;
    95             if (!neighborClasses.ContainsKey(upperClass)) neighborClasses[upperClass] = 1;
    96             else neighborClasses[upperClass]++;
     92            var upperClassSamples = trainedTargetPair[upper].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value
     93            var upperClasses = trainedTargetPair[upper].Value;
     94            foreach (var upperClass in upperClasses) {
     95              if (!neighborClasses.ContainsKey(upperClass.Key)) neighborClasses[upperClass.Key] = upperClass.Value;
     96              else neighborClasses[upperClass.Key] += upperClass.Value;
     97            }
    9798            upper++;
     99            i += (upperClassSamples - 1);
    98100          }
    99101        }
     
    113115      var classFrequencies = new Dictionary<double, int>();
    114116      foreach (var p in pair) {
     117        // get target values for each estimated value (foreach estimated value there may be different target values in different amounts)
    115118        if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>();
    116119        if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0;
    117120        dict[p.Estimated][p.Target]++;
    118 
     121        // get class frequencies
    119122        if (!classFrequencies.ContainsKey(p.Target))
    120123          classFrequencies[p.Target] = 1;
     
    124127      frequencyComparer = new ClassFrequencyComparer(classFrequencies);
    125128
    126       trainedTargetPair = new List<KeyValuePair<double, double>>();
     129      trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();
    127130      foreach (var ev in dict) {
    128         var target = ev.Value.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
    129         trainedTargetPair.Add(new KeyValuePair<double, double>(ev.Key, target));
     131        trainedTargetPair.Add(new KeyValuePair<double, Dictionary<double, int>>(ev.Key, ev.Value));
    130132      }
    131133      trainedTargetPair = trainedTargetPair.OrderBy(x => x.Key).ToList();
     
    133135
    134136    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    135       return new SymbolicClassificationSolution((ISymbolicClassificationModel)this.Clone(), problemData);
     137      return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData);
    136138    }
    137139  }
    138140
    139   internal class KeyValuePairKeyComparer : IComparer<KeyValuePair<double, double>> {
    140     public int Compare(KeyValuePair<double, double> x, KeyValuePair<double, double> y) {
     141  internal class KeyValuePairKeyComparer<T> : IComparer<KeyValuePair<double, T>> {
     142    public int Compare(KeyValuePair<double, T> x, KeyValuePair<double, T> y) {
    141143      return x.Key.CompareTo(y.Key);
    142144    }
     
    146148  internal sealed class ClassFrequencyComparer : IComparer<double> {
    147149    [Storable]
    148     private Dictionary<double, int> classFrequencies;
     150    private readonly Dictionary<double, int> classFrequencies;
    149151
    150152    [StorableConstructor]
Note: See TracChangeset for help on using the changeset viewer.