Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs @ 9002

Last change on this file since 9002 was 9002, checked in by abeham, 11 years ago

#1943: modified nearest neighbor model to include number of samples at each estimated value

File size: 8.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
31  /// <summary>
32  /// Represents a nearest neighbour model for regression and classification
33  /// </summary>
34  [StorableClass]
35  [Item("SymbolicNearestNeighbourClassificationModel", "Represents a nearest neighbour model for symbolic classification.")]
36  public sealed class SymbolicNearestNeighbourClassificationModel : SymbolicClassificationModel {
37
38    [Storable]
39    private int k;
40    [Storable]
41    private List<KeyValuePair<double, Dictionary<double, int>>> trainedTargetPair;
42    [Storable]
43    private ClassFrequencyComparer frequencyComparer;
44
45    [StorableConstructor]
46    private SymbolicNearestNeighbourClassificationModel(bool deserializing) : base(deserializing) { }
47    private SymbolicNearestNeighbourClassificationModel(SymbolicNearestNeighbourClassificationModel original, Cloner cloner)
48      : base(original, cloner) {
49      k = original.k;
50      trainedTargetPair = original.trainedTargetPair.Select(x => new KeyValuePair<double, Dictionary<double, int>>(x.Key, new Dictionary<double, int>(x.Value))).ToList();
51      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
52    }
53    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
54      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
55      this.k = k;
56      trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();
57      frequencyComparer = new ClassFrequencyComparer();
58    }
59
60    public override IDeepCloneable Clone(Cloner cloner) {
61      return new SymbolicNearestNeighbourClassificationModel(this, cloner);
62    }
63
64    public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
65      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
66                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
67      var neighborClasses = new Dictionary<double, int>();
68      foreach (var ev in estimatedValues) {
69        // find the index of the training-point to which distance is shortest
70        var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, Dictionary<double, int>>(ev, null), new KeyValuePairKeyComparer<Dictionary<double, int>>());
71        if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
72        var lower = upper - 1;
73        neighborClasses.Clear();
74        // continue to the left and right of this index and look for the nearest neighbors
75        for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) {
76          bool lowerIsCloser = upper >= trainedTargetPair.Count || (lower >= 0 && ev - trainedTargetPair[lower].Key <= trainedTargetPair[upper].Key - ev);
77          bool upperIsCloser = lower < 0 || (upper < trainedTargetPair.Count && ev - trainedTargetPair[lower].Key >= trainedTargetPair[upper].Key - ev);
78          if (lowerIsCloser) {
79            // the nearer neighbor is to the left
80            var lowerClassSamples = trainedTargetPair[lower].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value
81            var lowerClasses = trainedTargetPair[lower].Value;
82            foreach (var lowerClass in lowerClasses) {
83              if (!neighborClasses.ContainsKey(lowerClass.Key)) neighborClasses[lowerClass.Key] = lowerClass.Value;
84              else neighborClasses[lowerClass.Key] += lowerClass.Value;
85            }
86            lower--;
87            i += (lowerClassSamples - 1);
88          }
89          // they could, in very rare cases, be equally far apart
90          if (upperIsCloser) {
91            // the nearer neighbor is to the right
92            var upperClassSamples = trainedTargetPair[upper].Value.Select(x => x.Value).Sum(); // should be 1, except when multiple samples are estimated the same value
93            var upperClasses = trainedTargetPair[upper].Value;
94            foreach (var upperClass in upperClasses) {
95              if (!neighborClasses.ContainsKey(upperClass.Key)) neighborClasses[upperClass.Key] = upperClass.Value;
96              else neighborClasses[upperClass.Key] += upperClass.Value;
97            }
98            upper++;
99            i += (upperClassSamples - 1);
100          }
101        }
102        // majority voting with preference for bigger class in case of tie
103        yield return neighborClasses.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
104      }
105    }
106
107    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
108      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows)
109                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
110      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
111      var pair = estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t });
112
113      // there could be more than one target value per estimated value
114      var dict = new Dictionary<double, Dictionary<double, int>>();
115      var classFrequencies = new Dictionary<double, int>();
116      foreach (var p in pair) {
117        // get target values for each estimated value (foreach estimated value there may be different target values in different amounts)
118        if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>();
119        if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0;
120        dict[p.Estimated][p.Target]++;
121        // get class frequencies
122        if (!classFrequencies.ContainsKey(p.Target))
123          classFrequencies[p.Target] = 1;
124        else classFrequencies[p.Target]++;
125      }
126
127      frequencyComparer = new ClassFrequencyComparer(classFrequencies);
128
129      trainedTargetPair = new List<KeyValuePair<double, Dictionary<double, int>>>();
130      foreach (var ev in dict) {
131        trainedTargetPair.Add(new KeyValuePair<double, Dictionary<double, int>>(ev.Key, ev.Value));
132      }
133      trainedTargetPair = trainedTargetPair.OrderBy(x => x.Key).ToList();
134    }
135
136    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
137      return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData);
138    }
139  }
140
141  internal class KeyValuePairKeyComparer<T> : IComparer<KeyValuePair<double, T>> {
142    public int Compare(KeyValuePair<double, T> x, KeyValuePair<double, T> y) {
143      return x.Key.CompareTo(y.Key);
144    }
145  }
146
147  [StorableClass]
148  internal sealed class ClassFrequencyComparer : IComparer<double> {
149    [Storable]
150    private readonly Dictionary<double, int> classFrequencies;
151
152    [StorableConstructor]
153    private ClassFrequencyComparer(bool deserializing) { }
154    public ClassFrequencyComparer() {
155      classFrequencies = new Dictionary<double, int>();
156    }
157    public ClassFrequencyComparer(Dictionary<double, int> frequencies) {
158      classFrequencies = frequencies;
159    }
160    public ClassFrequencyComparer(ClassFrequencyComparer original) {
161      classFrequencies = new Dictionary<double, int>(original.classFrequencies);
162    }
163
164    public int Compare(double x, double y) {
165      bool cx = classFrequencies.ContainsKey(x), cy = classFrequencies.ContainsKey(y);
166      if (cx && cy)
167        return classFrequencies[x].CompareTo(classFrequencies[y]);
168      if (cx) return 1;
169      return -1;
170    }
171  }
172}
Note: See TracBrowser for help on using the repository browser.