source: branches/PersistenceOverhaul/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs @ 13386

Last change on this file since 13386 was 13386, checked in by ascheibe, 4 years ago

#2520

  • fixed duplicate guids
  • adapted/added unit tests for new persistence
File size: 7.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
31  /// <summary>
32  /// Represents a nearest neighbour model for regression and classification
33  /// </summary>
34  [StorableClass("225CCF16-C932-4D18-AF5A-0745FAD8F22C")]
35  [Item("SymbolicNearestNeighbourClassificationModel", "Represents a nearest neighbour model for symbolic classification.")]
36  public sealed class SymbolicNearestNeighbourClassificationModel : SymbolicClassificationModel {
37
38    [Storable]
39    private int k;
40    [Storable]
41    private List<double> trainedClasses;
42    [Storable]
43    private List<double> trainedEstimatedValues;
44
45    [Storable]
46    private ClassFrequencyComparer frequencyComparer;
47
48    [StorableConstructor]
49    private SymbolicNearestNeighbourClassificationModel(bool deserializing) : base(deserializing) { }
50    private SymbolicNearestNeighbourClassificationModel(SymbolicNearestNeighbourClassificationModel original, Cloner cloner)
51      : base(original, cloner) {
52      k = original.k;
53      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
54      trainedEstimatedValues = new List<double>(original.trainedEstimatedValues);
55      trainedClasses = new List<double>(original.trainedClasses);
56    }
57    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
58      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
59      this.k = k;
60      frequencyComparer = new ClassFrequencyComparer();
61
62    }
63
64    public override IDeepCloneable Clone(Cloner cloner) {
65      return new SymbolicNearestNeighbourClassificationModel(this, cloner);
66    }
67
68    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
69      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
70                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
71      foreach (var ev in estimatedValues) {
72        // find the range [lower, upper[ of trainedTargetValues that contains the k closest neighbours
73        // the range can span more than k elements when there are equal estimated values
74
75        // find the index of the training-point to which distance is shortest
76        int lower = trainedEstimatedValues.BinarySearch(ev);
77        int upper;
78        // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
79        if (lower < 0) {
80          lower = ~lower;
81          // lower is not necessarily the closer one
82          // determine which element is closer to ev (lower - 1) or (lower)
83          if (lower == trainedEstimatedValues.Count ||
84            (lower > 0 && Math.Abs(ev - trainedEstimatedValues[lower - 1]) < Math.Abs(ev - trainedEstimatedValues[lower]))) {
85            lower = lower - 1;
86          }
87        }
88        upper = lower + 1;
89        // at this point we have a range [lower, upper[ that includes only the closest element to ev
90
91        // expand the range to left or right looking for the nearest neighbors
92        while (upper - lower < Math.Min(k, trainedEstimatedValues.Count)) {
93          bool lowerIsCloser = upper >= trainedEstimatedValues.Count ||
94                               (lower > 0 && ev - trainedEstimatedValues[lower] <= trainedEstimatedValues[upper] - ev);
95          bool upperIsCloser = lower <= 0 ||
96                               (upper < trainedEstimatedValues.Count &&
97                                ev - trainedEstimatedValues[lower] >= trainedEstimatedValues[upper] - ev);
98          if (!lowerIsCloser && !upperIsCloser) break;
99          if (lowerIsCloser) {
100            lower--;
101            // eat up all equal values
102            while (lower > 0 && trainedEstimatedValues[lower - 1].IsAlmost(trainedEstimatedValues[lower]))
103              lower--;
104          }
105          if (upperIsCloser) {
106            upper++;
107            while (upper < trainedEstimatedValues.Count &&
108                   trainedEstimatedValues[upper - 1].IsAlmost(trainedEstimatedValues[upper]))
109              upper++;
110          }
111        }
112        // majority voting with preference for bigger class in case of tie
113        yield return Enumerable.Range(lower, upper - lower)
114          .Select(i => trainedClasses[i])
115          .GroupBy(c => c)
116          .Select(g => new { Class = g.Key, Votes = g.Count() })
117          .MaxItems(p => p.Votes)
118          .OrderByDescending(m => m.Class, frequencyComparer)
119          .First().Class;
120      }
121    }
122
123    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
124      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows)
125                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
126      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
127      var trainedClasses = targetValues.ToArray();
128      var trainedEstimatedValues = estimatedValues.ToArray();
129
130      Array.Sort(trainedEstimatedValues, trainedClasses);
131      this.trainedClasses = new List<double>(trainedClasses);
132      this.trainedEstimatedValues = new List<double>(trainedEstimatedValues);
133
134      var freq = trainedClasses
135        .GroupBy(c => c)
136        .ToDictionary(g => g.Key, g => g.Count());
137      this.frequencyComparer = new ClassFrequencyComparer(freq);
138    }
139
140    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
141      return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData);
142    }
143  }
144
145  [StorableClass("01561669-12E6-4C75-86BF-88C24DA53FDD")]
146  internal sealed class ClassFrequencyComparer : IComparer<double> {
147    [Storable]
148    private readonly Dictionary<double, int> classFrequencies;
149
150    [StorableConstructor]
151    private ClassFrequencyComparer(bool deserializing) { }
152    public ClassFrequencyComparer() {
153      classFrequencies = new Dictionary<double, int>();
154    }
155    public ClassFrequencyComparer(Dictionary<double, int> frequencies) {
156      classFrequencies = frequencies;
157    }
158    public ClassFrequencyComparer(ClassFrequencyComparer original) {
159      classFrequencies = new Dictionary<double, int>(original.classFrequencies);
160    }
161
162    public int Compare(double x, double y) {
163      bool cx = classFrequencies.ContainsKey(x), cy = classFrequencies.ContainsKey(y);
164      if (cx && cy)
165        return classFrequencies[x].CompareTo(classFrequencies[y]);
166      if (cx) return 1;
167      return -1;
168    }
169  }
170}
Note: See TracBrowser for help on using the repository browser.