Changeset 8978


Ignore:
Timestamp:
11/30/12 12:13:26 (10 years ago)
Author:
abeham
Message:

#1943: review comments

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/ModelCreators/NearestNeighborModelCreator.cs

    r8606 r8978  
    4848    public NearestNeighborModelCreator()
    4949      : base() {
    50       Parameters.Add(new FixedValueParameter<IntValue>("K", "The number of neighbours to use to determine the class.", new IntValue(3)));
     50      Parameters.Add(new FixedValueParameter<IntValue>("K", "The number of neighbours to use to determine the class.", new IntValue(11)));
    5151    }
    5252
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs

    r8606 r8978  
    4040    [Storable]
    4141    private List<KeyValuePair<double, double>> trainedTargetPair;
     42    [Storable]
     43    private ClassFrequencyComparer frequencyComparer;
    4244
    4345    [StorableConstructor]
     
    4749      k = original.k;
    4850      trainedTargetPair = new List<KeyValuePair<double, double>>(original.trainedTargetPair);
     51      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
    4952    }
    5053    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
     
    5255      this.k = k;
    5356      this.trainedTargetPair = new List<KeyValuePair<double, double>>();
     57      frequencyComparer = new ClassFrequencyComparer();
    5458    }
    5559
     
    5862    }
    5963
     64    [StorableHook(HookType.AfterDeserialization)]
     65    private void AfterDeserialization() {
     66      if (frequencyComparer == null) {
     67        var dict = trainedTargetPair
     68          .GroupBy(x => x.Value)
     69          .ToDictionary(x => x.Key, y => y.Count());
     70        frequencyComparer = new ClassFrequencyComparer(dict);
     71      }
     72    }
     73
    6074    public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    61       var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows);
    62       var neighbors = new Dictionary<double, int>();
     75      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
     76                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
     77      var neighborClasses = new Dictionary<double, int>();
    6378      foreach (var ev in estimatedValues) {
    64         int lower = 0, upper = 1;
    65         double sdist = Math.Abs(ev - trainedTargetPair[0].Key);
    66         for (int i = 1; i < trainedTargetPair.Count; i++) {
    67           double d = Math.Abs(ev - trainedTargetPair[i].Key);
    68           if (d > sdist) break;
    69           lower = i;
    70           upper = i + 1;
    71           sdist = d;
    72         }
    73         neighbors.Clear();
    74         neighbors[trainedTargetPair[lower].Value] = 1;
    75         lower--;
    76         for (int i = 1; i < Math.Min(k, trainedTargetPair.Count); i++) {
     79        // find the index of the training-point to which distance is shortest
     80        var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, double>(ev, double.NaN), new KeyValuePairKeyComparer());
     81        if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
     82        var lower = upper - 1;
     83        neighborClasses.Clear();
     84        // continue to the left and right of this index and look for the nearest neighbors
     85        for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) {
    7786          if (upper >= trainedTargetPair.Count || (lower > 0 && ev - trainedTargetPair[lower].Key < trainedTargetPair[upper].Key - ev)) {
    78             if (!neighbors.ContainsKey(trainedTargetPair[lower].Value))
    79               neighbors[trainedTargetPair[lower].Value] = 1;
    80             else neighbors[trainedTargetPair[lower].Value]++;
     87            // the nearer neighbor is to the left
     88            var lowerClass = trainedTargetPair[lower].Value;
     89            if (!neighborClasses.ContainsKey(lowerClass)) neighborClasses[lowerClass] = 1;
     90            else neighborClasses[lowerClass]++;
    8191            lower--;
    8292          } else {
    83             if (!neighbors.ContainsKey(trainedTargetPair[upper].Value))
    84               neighbors[trainedTargetPair[upper].Value] = 1;
    85             else neighbors[trainedTargetPair[upper].Value]++;
     93            // the nearer neighbor is to the right
     94            var upperClass = trainedTargetPair[upper].Value;
     95            if (!neighborClasses.ContainsKey(upperClass)) neighborClasses[upperClass] = 1;
     96            else neighborClasses[upperClass]++;
    8697            upper++;
    8798          }
    8899        }
    89         yield return neighbors.MaxItems(x => x.Value).First().Key;
     100        // majority voting with preference for bigger class in case of tie
     101        yield return neighborClasses.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
    90102      }
    91103    }
    92104
    93105    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
    94       var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows);
     106      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows)
     107                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
    95108      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    96109      var pair = estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t });
     
    98111      // there could be more than one target value per estimated value
    99112      var dict = new Dictionary<double, Dictionary<double, int>>();
     113      var classFrequencies = new Dictionary<double, int>();
    100114      foreach (var p in pair) {
    101115        if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>();
    102116        if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0;
    103117        dict[p.Estimated][p.Target]++;
     118
     119        if (!classFrequencies.ContainsKey(p.Target))
     120          classFrequencies[p.Target] = 1;
     121        else classFrequencies[p.Target]++;
    104122      }
     123
     124      frequencyComparer = new ClassFrequencyComparer(classFrequencies);
    105125
    106126      trainedTargetPair = new List<KeyValuePair<double, double>>();
    107127      foreach (var ev in dict) {
    108         var target = ev.Value.MaxItems(x => x.Value).First().Key;
     128        var target = ev.Value.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
    109129        trainedTargetPair.Add(new KeyValuePair<double, double>(ev.Key, target));
    110130      }
     
    116136    }
    117137  }
     138
     139  internal class KeyValuePairKeyComparer : IComparer<KeyValuePair<double, double>> {
     140    public int Compare(KeyValuePair<double, double> x, KeyValuePair<double, double> y) {
     141      return x.Key.CompareTo(y.Key);
     142    }
     143  }
     144
     145  [StorableClass]
     146  internal class ClassFrequencyComparer : IComparer<double> {
     147    [Storable]
     148    private Dictionary<double, int> classFrequencies;
     149
     150    [StorableConstructor]
     151    private ClassFrequencyComparer(bool deserializing) { }
     152    public ClassFrequencyComparer() {
     153      classFrequencies = new Dictionary<double, int>();
     154    }
     155    public ClassFrequencyComparer(Dictionary<double, int> frequencies) {
     156      classFrequencies = frequencies;
     157    }
     158    public ClassFrequencyComparer(ClassFrequencyComparer original) {
     159      classFrequencies = new Dictionary<double, int>(original.classFrequencies);
     160    }
     161
     162    public int Compare(double x, double y) {
     163      bool cx = classFrequencies.ContainsKey(x), cy = classFrequencies.ContainsKey(y);
     164      if (cx && cy)
     165        return classFrequencies[x].CompareTo(classFrequencies[y]);
     166      if (cx) return 1;
     167      return -1;
     168    }
     169  }
    118170}
Note: See TracChangeset for help on using the changeset viewer.