Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs @ 8978

Last change on this file since 8978 was 8978, checked in by abeham, 11 years ago

#1943: review comments

File size: 8.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
31  /// <summary>
32  /// Represents a nearest neighbour model for regression and classification
33  /// </summary>
34  [StorableClass]
35  [Item("SymbolicNearestNeighbourClassificationModel", "Represents a nearest neighbour model for symbolic classification.")]
36  public sealed class SymbolicNearestNeighbourClassificationModel : SymbolicClassificationModel {
37
38    [Storable]
39    private int k;
40    [Storable]
41    private List<KeyValuePair<double, double>> trainedTargetPair;
42    [Storable]
43    private ClassFrequencyComparer frequencyComparer;
44
45    [StorableConstructor]
46    private SymbolicNearestNeighbourClassificationModel(bool deserializing) : base(deserializing) { }
47    private SymbolicNearestNeighbourClassificationModel(SymbolicNearestNeighbourClassificationModel original, Cloner cloner)
48      : base(original, cloner) {
49      k = original.k;
50      trainedTargetPair = new List<KeyValuePair<double, double>>(original.trainedTargetPair);
51      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
52    }
53    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
54      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
55      this.k = k;
56      this.trainedTargetPair = new List<KeyValuePair<double, double>>();
57      frequencyComparer = new ClassFrequencyComparer();
58    }
59
60    public override IDeepCloneable Clone(Cloner cloner) {
61      return new SymbolicNearestNeighbourClassificationModel(this, cloner);
62    }
63
64    [StorableHook(HookType.AfterDeserialization)]
65    private void AfterDeserialization() {
66      if (frequencyComparer == null) {
67        var dict = trainedTargetPair
68          .GroupBy(x => x.Value)
69          .ToDictionary(x => x.Key, y => y.Count());
70        frequencyComparer = new ClassFrequencyComparer(dict);
71      }
72    }
73
74    public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
75      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
76                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
77      var neighborClasses = new Dictionary<double, int>();
78      foreach (var ev in estimatedValues) {
79        // find the index of the training-point to which distance is shortest
80        var upper = trainedTargetPair.BinarySearch(0, trainedTargetPair.Count, new KeyValuePair<double, double>(ev, double.NaN), new KeyValuePairKeyComparer());
81        if (upper < 0) upper = ~upper; // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
82        var lower = upper - 1;
83        neighborClasses.Clear();
84        // continue to the left and right of this index and look for the nearest neighbors
85        for (int i = 0; i < Math.Min(k, trainedTargetPair.Count); i++) {
86          if (upper >= trainedTargetPair.Count || (lower > 0 && ev - trainedTargetPair[lower].Key < trainedTargetPair[upper].Key - ev)) {
87            // the nearer neighbor is to the left
88            var lowerClass = trainedTargetPair[lower].Value;
89            if (!neighborClasses.ContainsKey(lowerClass)) neighborClasses[lowerClass] = 1;
90            else neighborClasses[lowerClass]++;
91            lower--;
92          } else {
93            // the nearer neighbor is to the right
94            var upperClass = trainedTargetPair[upper].Value;
95            if (!neighborClasses.ContainsKey(upperClass)) neighborClasses[upperClass] = 1;
96            else neighborClasses[upperClass]++;
97            upper++;
98          }
99        }
100        // majority voting with preference for bigger class in case of tie
101        yield return neighborClasses.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
102      }
103    }
104
105    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
106      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows)
107                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
108      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
109      var pair = estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t });
110
111      // there could be more than one target value per estimated value
112      var dict = new Dictionary<double, Dictionary<double, int>>();
113      var classFrequencies = new Dictionary<double, int>();
114      foreach (var p in pair) {
115        if (!dict.ContainsKey(p.Estimated)) dict[p.Estimated] = new Dictionary<double, int>();
116        if (!dict[p.Estimated].ContainsKey(p.Target)) dict[p.Estimated][p.Target] = 0;
117        dict[p.Estimated][p.Target]++;
118
119        if (!classFrequencies.ContainsKey(p.Target))
120          classFrequencies[p.Target] = 1;
121        else classFrequencies[p.Target]++;
122      }
123
124      frequencyComparer = new ClassFrequencyComparer(classFrequencies);
125
126      trainedTargetPair = new List<KeyValuePair<double, double>>();
127      foreach (var ev in dict) {
128        var target = ev.Value.MaxItems(x => x.Value).OrderByDescending(x => x.Key, frequencyComparer).First().Key;
129        trainedTargetPair.Add(new KeyValuePair<double, double>(ev.Key, target));
130      }
131      trainedTargetPair = trainedTargetPair.OrderBy(x => x.Key).ToList();
132    }
133
134    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
135      return new SymbolicClassificationSolution((ISymbolicClassificationModel)this.Clone(), problemData);
136    }
137  }
138
139  internal class KeyValuePairKeyComparer : IComparer<KeyValuePair<double, double>> {
140    public int Compare(KeyValuePair<double, double> x, KeyValuePair<double, double> y) {
141      return x.Key.CompareTo(y.Key);
142    }
143  }
144
145  [StorableClass]
146  internal class ClassFrequencyComparer : IComparer<double> {
147    [Storable]
148    private Dictionary<double, int> classFrequencies;
149
150    [StorableConstructor]
151    private ClassFrequencyComparer(bool deserializing) { }
152    public ClassFrequencyComparer() {
153      classFrequencies = new Dictionary<double, int>();
154    }
155    public ClassFrequencyComparer(Dictionary<double, int> frequencies) {
156      classFrequencies = frequencies;
157    }
158    public ClassFrequencyComparer(ClassFrequencyComparer original) {
159      classFrequencies = new Dictionary<double, int>(original.classFrequencies);
160    }
161
162    public int Compare(double x, double y) {
163      bool cx = classFrequencies.ContainsKey(x), cy = classFrequencies.ContainsKey(y);
164      if (cx && cy)
165        return classFrequencies[x].CompareTo(classFrequencies[y]);
166      if (cx) return 1;
167      return -1;
168    }
169  }
170}
Note: See TracBrowser for help on using the repository browser.