Free cookie consent management tool by TermsFeed Policy Generator

Changeset 9003


Ignore:
Timestamp:
12/05/12 18:54:25 (11 years ago)
Author:
gkronber
Message:

#1943: overhauled nearest neighbor routine

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs

    r9002 r9003  
    3939    private int k;
    4040    [Storable]
    41     private List<KeyValuePair<double, Dictionary<double, int>>> trainedTargetPair;
     41    private List<double> trainedClasses;
     42    [Storable]
     43    private List<double> trainedEstimatedValues;
     44
    4245    [Storable]
    4346    private ClassFrequencyComparer frequencyComparer;
     
    4851      : base(original, cloner) {
    4952      k = original.k;
    50       trainedTargetPair = original.trainedTargetPair.Select(x => new KeyValuePair<double, Dictionary<double, int>>(x.Key, new Dictionary<double, int>(x.Value))).ToList();
    5153      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
     54      trainedEstimatedValues = new List<double>(original.trainedEstimatedValues);
     55      trainedClasses = new List<double>(original.trainedClasses);
    5256    }
    5357    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
    5458      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
    5559      this.k = k;
    56       trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();
    5760      frequencyComparer = new ClassFrequencyComparer();
     61
    5862    }
    5963
     
    6569      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
    6670                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
    67       var neighborClasses = new Dictionary<double, int>();
    6871      foreach (var ev in estimatedValues) {
     72        // find the range [lower, upper[ of trainedEstimatedValues that contains the k closest neighbours
     73        // the range can span more than k elements when there are equal estimated values
     74
    6975        // find the index of the training-point to which distance is shortest
    70         var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, Dictionary<double, int>>(ev, null), new KeyValuePairKeyComparer<Dictionary<double, int>>());
    71         if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
    72         var lower = upper - 1;
    73         neighborClasses.Clear();
    74         // continue to the left and right of this index and look for the nearest neighbors
    75         for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) {
    76           bool lowerIsCloser = upper >= trainedTargetPair.Count || (lower >= 0 && ev - trainedTargetPair[lower].Key <= trainedTargetPair[upper].Key - ev);
    77           bool upperIsCloser = lower < 0 || (upper < trainedTargetPair.Count && ev - trainedTargetPair[lower].Key >= trainedTargetPair[upper].Key - ev);
     76        int lower = trainedEstimatedValues.BinarySearch(ev);
     77        int upper;
     78        // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
     79        if (lower < 0) {
     80          lower = ~lower;
     81          // lower is not necessarily the closer one
     82          // determine which element is closer to ev (lower - 1) or (lower)
     83          if (lower == trainedEstimatedValues.Count ||
     84            (lower > 0 && Math.Abs(ev - trainedEstimatedValues[lower - 1]) < Math.Abs(ev - trainedEstimatedValues[lower]))) {
     85            lower = lower - 1;
     86          }
     87        }
     88        upper = lower + 1;
     89        // at this point we have a range [lower, upper[ that includes only the closest element to ev
     90
     91        // expand the range to left or right looking for the nearest neighbors
     92        while (upper - lower < Math.Min(k, trainedEstimatedValues.Count)) {
     93          bool lowerIsCloser = upper >= trainedEstimatedValues.Count ||
     94                               (lower > 0 && ev - trainedEstimatedValues[lower] <= trainedEstimatedValues[upper] - ev);
     95          bool upperIsCloser = lower <= 0 ||
     96                               (upper < trainedEstimatedValues.Count &&
     97                                ev - trainedEstimatedValues[lower] >= trainedEstimatedValues[upper] - ev);
     98          if (!lowerIsCloser && !upperIsCloser) break;
    7899          if (lowerIsCloser) {
    79             // the nearer neighbor is to the left
    80             var lowerClassSamples = trainedTargetPair[lower].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value
    81             var lowerClasses = trainedTargetPair[lower].Value;
    82             foreach (var lowerClass in lowerClasses) {
    83               if (!neighborClasses.ContainsKey(lowerClass.Key)) neighborClasses[lowerClass.Key] = lowerClass.Value;
    84               else neighborClasses[lowerClass.Key] += lowerClass.Value;
    85             }
    86100            lower--;
    87             i += (lowerClassSamples - 1);
     101            // eat up all equal values
     102            while (lower > 0 && trainedEstimatedValues[lower - 1].IsAlmost(trainedEstimatedValues[lower]))
     103              lower--;
    88104          }
    89           // they could, in very rare cases, be equally far apart
    90105          if (upperIsCloser) {
    91             // the nearer neighbor is to the right
    92             var upperClassSamples = trainedTargetPair[upper].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value
    93             var upperClasses = trainedTargetPair[upper].Value;
    94             foreach (var upperClass in upperClasses) {
    95               if (!neighborClasses.ContainsKey(upperClass.Key)) neighborClasses[upperClass.Key] = upperClass.Value;
    96               else neighborClasses[upperClass.Key] += upperClass.Value;
    97             }
    98106            upper++;
    99             i += (upperClassSamples - 1);
     107            while (upper < trainedEstimatedValues.Count &&
     108                   trainedEstimatedValues[upper - 1].IsAlmost(trainedEstimatedValues[upper]))
     109              upper++;
    100110          }
    101111        }
    102112        // majority voting with preference for bigger class in case of tie
    103         yield return neighborClasses.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
     113        yield return Enumerable.Range(lower, upper - lower)
     114          .Select(i => trainedClasses[i])
     115          .GroupBy(c => c)
     116          .Select(g => new { Class = g.Key, Votes = g.Count() })
     117          .MaxItems(p => p.Votes)
     118          .OrderByDescending(m => m.Class, frequencyComparer)
     119          .First().Class;
    104120      }
    105121    }
     
    109125                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
    110126      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    111       var pair = estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t });
     127      var trainedClasses = targetValues.ToArray();
     128      var trainedEstimatedValues = estimatedValues.ToArray();
    112129
    113       // there could be more than one target value per estimated value
    114       var dict = new Dictionary<double, Dictionary<double, int>>();
    115       var classFrequencies = new Dictionary<double, int>();
    116       foreach (var p in pair) {
    117         // get target values for each estimated value (foreach estimated value there may be different target values in different amounts)
    118         if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>();
    119         if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0;
    120         dict[p.Estimated][p.Target]++;
    121         // get class frequencies
    122         if (!classFrequencies.ContainsKey(p.Target))
    123           classFrequencies[p.Target] = 1;
    124         else classFrequencies[p.Target]++;
    125       }
     130      Array.Sort(trainedEstimatedValues, trainedClasses);
     131      this.trainedClasses = new List<double>(trainedClasses);
     132      this.trainedEstimatedValues = new List<double>(trainedEstimatedValues);
    126133
    127       frequencyComparer = new ClassFrequencyComparer(classFrequencies);
    128 
    129       trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();
    130       foreach (var ev in dict) {
    131         trainedTargetPair.Add(new KeyValuePair<double, Dictionary<double, int>>(ev.Key, ev.Value));
    132       }
    133       trainedTargetPair = trainedTargetPair.OrderBy(x => x.Key).ToList();
     134      var freq = trainedClasses
     135        .GroupBy(c => c)
     136        .ToDictionary(g => g.Key, g => g.Count());
     137      this.frequencyComparer = new ClassFrequencyComparer(freq);
    134138    }
    135139
    136140    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    137141      return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData);
    138     }
    139   }
    140 
    141   internal class KeyValuePairKeyComparer<T> : IComparer<KeyValuePair<double, T>> {
    142     public int Compare(KeyValuePair<double, T> x, KeyValuePair<double, T> y) {
    143       return x.Key.CompareTo(y.Key);
    144142    }
    145143  }
Note: See TracChangeset for help on using the changeset viewer.