Free cookie consent management tool by TermsFeed Policy Generator

source: branches/PersistenceOverhaul/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs @ 15428

Last change on this file since 15428 was 14711, checked in by gkronber, 8 years ago

#2520

  • renamed StorableClass -> StorableType
  • changed persistence to use GUIDs instead of type names
File size: 7.8 KB
RevLine 
[13368]1#region License Information
[8606]2/* HeuristicLab
[12012]3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[8606]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
31  /// <summary>
32  /// Represents a nearest neighbour model for regression and classification
33  /// </summary>
[14711]34  [StorableType("225CCF16-C932-4D18-AF5A-0745FAD8F22C")]
[8606]35  [Item("SymbolicNearestNeighbourClassificationModel", "Represents a nearest neighbour model for symbolic classification.")]
36  public sealed class SymbolicNearestNeighbourClassificationModel : SymbolicClassificationModel {
37
38    [Storable]
39    private int k;
40    [Storable]
[9003]41    private List<double> trainedClasses;
[8978]42    [Storable]
[9003]43    private List<double> trainedEstimatedValues;
44
45    [Storable]
[8978]46    private ClassFrequencyComparer frequencyComparer;
[8606]47
48    [StorableConstructor]
49    private SymbolicNearestNeighbourClassificationModel(bool deserializing) : base(deserializing) { }
50    private SymbolicNearestNeighbourClassificationModel(SymbolicNearestNeighbourClassificationModel original, Cloner cloner)
51      : base(original, cloner) {
52      k = original.k;
[8978]53      frequencyComparer = new ClassFrequencyComparer(original.frequencyComparer);
[9003]54      trainedEstimatedValues = new List<double>(original.trainedEstimatedValues);
55      trainedClasses = new List<double>(original.trainedClasses);
[8606]56    }
57    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
58      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
59      this.k = k;
[8978]60      frequencyComparer = new ClassFrequencyComparer();
[9003]61
[8606]62    }
63
64    public override IDeepCloneable Clone(Cloner cloner) {
65      return new SymbolicNearestNeighbourClassificationModel(this, cloner);
66    }
67
[12509]68    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
[8978]69      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows)
70                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
[8606]71      foreach (var ev in estimatedValues) {
[9003]72        // find the range [lower, upper[ of trainedTargetValues that contains the k closest neighbours
73        // the range can span more than k elements when there are equal estimated values
74
[8978]75        // find the index of the training-point to which distance is shortest
[9003]76        int lower = trainedEstimatedValues.BinarySearch(ev);
77        int upper;
78        // if the element was not found exactly, BinarySearch returns the complement of the index of the next larger item
79        if (lower < 0) {
80          lower = ~lower;
81          // lower is not necessarily the closer one
82          // determine which element is closer to ev (lower - 1) or (lower)
83          if (lower == trainedEstimatedValues.Count ||
84            (lower > 0 && Math.Abs(ev - trainedEstimatedValues[lower - 1]) < Math.Abs(ev - trainedEstimatedValues[lower]))) {
85            lower = lower - 1;
86          }
87        }
88        upper = lower + 1;
89        // at this point we have a range [lower, upper[ that includes only the closest element to ev
90
91        // expand the range to left or right looking for the nearest neighbors
92        while (upper - lower < Math.Min(k, trainedEstimatedValues.Count)) {
93          bool lowerIsCloser = upper >= trainedEstimatedValues.Count ||
94                               (lower > 0 && ev - trainedEstimatedValues[lower] <= trainedEstimatedValues[upper] - ev);
95          bool upperIsCloser = lower <= 0 ||
96                               (upper < trainedEstimatedValues.Count &&
97                                ev - trainedEstimatedValues[lower] >= trainedEstimatedValues[upper] - ev);
98          if (!lowerIsCloser && !upperIsCloser) break;
[9002]99          if (lowerIsCloser) {
[8606]100            lower--;
[9003]101            // eat up all equal values
102            while (lower > 0 && trainedEstimatedValues[lower - 1].IsAlmost(trainedEstimatedValues[lower]))
103              lower--;
[9002]104          }
105          if (upperIsCloser) {
[8606]106            upper++;
[9003]107            while (upper < trainedEstimatedValues.Count &&
108                   trainedEstimatedValues[upper - 1].IsAlmost(trainedEstimatedValues[upper]))
109              upper++;
[8606]110          }
111        }
[8978]112        // majority voting with preference for bigger class in case of tie
[9003]113        yield return Enumerable.Range(lower, upper - lower)
114          .Select(i => trainedClasses[i])
115          .GroupBy(c => c)
116          .Select(g => new { Class = g.Key, Votes = g.Count() })
117          .MaxItems(p => p.Votes)
118          .OrderByDescending(m => m.Class, frequencyComparer)
119          .First().Class;
[8606]120      }
121    }
122
123    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
[8978]124      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows)
125                                       .LimitToRange(LowerEstimationLimit, UpperEstimationLimit);
[8606]126      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
[9003]127      var trainedClasses = targetValues.ToArray();
128      var trainedEstimatedValues = estimatedValues.ToArray();
[8606]129
[9003]130      Array.Sort(trainedEstimatedValues, trainedClasses);
131      this.trainedClasses = new List<double>(trainedClasses);
132      this.trainedEstimatedValues = new List<double>(trainedEstimatedValues);
[8606]133
[9003]134      var freq = trainedClasses
135        .GroupBy(c => c)
136        .ToDictionary(g => g.Key, g => g.Count());
137      this.frequencyComparer = new ClassFrequencyComparer(freq);
[8606]138    }
139
140    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
[9002]141      return new SymbolicClassificationSolution((ISymbolicClassificationModel)Clone(), problemData);
[8606]142    }
143  }
[8978]144
[14711]145  [StorableType("01561669-12E6-4C75-86BF-88C24DA53FDD")]
[8979]146  internal sealed class ClassFrequencyComparer : IComparer<double> {
[8978]147    [Storable]
[9002]148    private readonly Dictionary<double, int> classFrequencies;
[8978]149
150    [StorableConstructor]
151    private ClassFrequencyComparer(bool deserializing) { }
152    public ClassFrequencyComparer() {
153      classFrequencies = new Dictionary<double, int>();
154    }
155    public ClassFrequencyComparer(Dictionary<double, int> frequencies) {
156      classFrequencies = frequencies;
157    }
158    public ClassFrequencyComparer(ClassFrequencyComparer original) {
159      classFrequencies = new Dictionary<double, int>(original.classFrequencies);
160    }
161
162    public int Compare(double x, double y) {
163      bool cx = classFrequencies.ContainsKey(x), cy = classFrequencies.ContainsKey(y);
164      if (cx && cy)
165        return classFrequencies[x].CompareTo(classFrequencies[y]);
166      if (cx) return 1;
167      return -1;
168    }
169  }
[8606]170}
Note: See TracBrowser for help on using the repository browser.