Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GP-MoveOperators/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicNearestNeighbourClassificationModel.cs @ 12147

Last change to this file as of revision 12147 was revision 8660, checked in by gkronber, 12 years ago

#1847 merged r8205:8635 from trunk into branch

File size: 5.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
  /// <summary>
  /// Represents a symbolic classification model that assigns a class by a k-nearest-neighbour vote:
  /// the output of the symbolic expression tree for a row is compared with the tree outputs recorded
  /// for the training rows, and the majority class among the k closest training outputs is returned.
  /// </summary>
  [StorableClass]
  [Item("SymbolicNearestNeighbourClassificationModel", "Represents a nearest neighbour model for symbolic classification.")]
  public sealed class SymbolicNearestNeighbourClassificationModel : SymbolicClassificationModel {

    // number of nearest training outputs that vote on the class of a row
    [Storable]
    private int k;
    // one (tree output, majority target class) pair per distinct training tree output,
    // kept sorted ascending by tree output (see RecalculateModelParameters)
    [Storable]
    private List<KeyValuePair<double, double>> trainedTargetPair;

    [StorableConstructor]
    private SymbolicNearestNeighbourClassificationModel(bool deserializing) : base(deserializing) { }
    private SymbolicNearestNeighbourClassificationModel(SymbolicNearestNeighbourClassificationModel original, Cloner cloner)
      : base(original, cloner) {
      k = original.k;
      // KeyValuePair<double, double> is an immutable struct, so copying the list is a full copy
      trainedTargetPair = new List<KeyValuePair<double, double>>(original.trainedTargetPair);
    }

    /// <summary>
    /// Creates an untrained model; call <see cref="RecalculateModelParameters"/> before estimating class values.
    /// </summary>
    /// <param name="k">Number of nearest training outputs that vote on the class of a row.
    /// Values smaller than 1 behave like 1 (only the single nearest training output decides).</param>
    /// <param name="tree">Symbolic expression tree whose output is used as the one-dimensional feature.</param>
    /// <param name="interpreter">Interpreter used to evaluate <paramref name="tree"/>.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit, forwarded to the base model.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit, forwarded to the base model.</param>
    // NOTE(review): the estimation limits are only passed on to the base class; this class does not clamp
    // tree outputs itself — confirm whether the interpreter or base class applies them.
    public SymbolicNearestNeighbourClassificationModel(int k, ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue)
      : base(tree, interpreter, lowerEstimationLimit, upperEstimationLimit) {
      this.k = k;
      this.trainedTargetPair = new List<KeyValuePair<double, double>>();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicNearestNeighbourClassificationModel(this, cloner);
    }

    /// <summary>
    /// Estimates a class value for each of the given rows.
    /// </summary>
    /// <remarks>
    /// The sequence is evaluated lazily. For each row, the tree output is located among the sorted training
    /// outputs; starting from the nearest one, the neighbourhood is grown one step at a time towards whichever
    /// side (smaller or larger outputs) is closer, until min(k, number of training outputs) neighbours have voted.
    /// When a smaller-side and a larger-side candidate are equally far away, the larger-side one is taken first.
    /// The class with the most votes is returned.
    /// </remarks>
    /// <param name="dataset">Dataset on which the tree is evaluated.</param>
    /// <param name="rows">Rows for which class values are estimated.</param>
    /// <returns>One estimated class value per row.</returns>
    /// <exception cref="InvalidOperationException">Thrown during enumeration when at least one row is requested
    /// but the model holds no training pairs (RecalculateModelParameters was not called or saw no rows).</exception>
    public override IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows);
      // read the field once so the whole (lazy) enumeration uses one consistent set of training pairs,
      // even if the model is recalculated while the sequence is still being consumed
      var trained = trainedTargetPair;
      int neighbourCount = Math.Min(k, trained.Count);
      var votes = new Dictionary<double, int>();
      foreach (var ev in estimatedValues) {
        if (trained.Count == 0)
          throw new InvalidOperationException("The nearest neighbour model has no training data. Call RecalculateModelParameters before estimating class values.");

        int nearest = FindNearestIndex(trained, ev);
        votes.Clear();
        AddVote(votes, trained[nearest].Value);

        // lower/upper are the next unvisited candidates on either side of the current neighbourhood
        int lower = nearest - 1, upper = nearest + 1;
        for (int i = 1; i < neighbourCount; i++) {
          // lower >= 0 (not > 0): index 0 is a valid neighbour and must be able to beat the upper candidate
          bool takeLower = upper >= trained.Count
            || (lower >= 0 && Math.Abs(ev - trained[lower].Key) < Math.Abs(trained[upper].Key - ev));
          if (takeLower) {
            AddVote(votes, trained[lower].Value);
            lower--;
          } else {
            AddVote(votes, trained[upper].Value);
            upper++;
          }
        }
        yield return votes.MaxItems(x => x.Value).First().Key;
      }
    }

    /// <summary>
    /// Re-learns the training pairs from the given rows: the tree is evaluated on each row, and every distinct
    /// tree output is associated with the target class that occurs most often among the rows producing it.
    /// The resulting pairs are stored sorted ascending by tree output.
    /// </summary>
    /// <param name="problemData">Problem data providing the dataset and the target variable.</param>
    /// <param name="rows">Rows used for training (enumerated once for the tree outputs and once for the targets).</param>
    public override void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, problemData.Dataset, rows);
      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);

      // several rows can produce the same tree output but carry different targets; count targets per output
      var targetCountsPerEstimate = new Dictionary<double, Dictionary<double, int>>();
      foreach (var p in estimatedValues.Zip(targetValues, (e, t) => new { Estimated = e, Target = t })) {
        Dictionary<double, int> targetCounts;
        if (!targetCountsPerEstimate.TryGetValue(p.Estimated, out targetCounts)) {
          targetCounts = new Dictionary<double, int>();
          targetCountsPerEstimate[p.Estimated] = targetCounts;
        }
        AddVote(targetCounts, p.Target);
      }

      var pairs = new List<KeyValuePair<double, double>>(targetCountsPerEstimate.Count);
      foreach (var entry in targetCountsPerEstimate) {
        var majorityTarget = entry.Value.MaxItems(x => x.Value).First().Key;
        pairs.Add(new KeyValuePair<double, double>(entry.Key, majorityTarget));
      }
      // sorted keys are required by the neighbour search in GetEstimatedClassValues
      trainedTargetPair = pairs.OrderBy(x => x.Key).ToList();
    }

    /// <summary>
    /// Creates a classification solution that wraps a clone of this model.
    /// </summary>
    public override ISymbolicClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new SymbolicClassificationSolution((ISymbolicClassificationModel)this.Clone(), problemData);
    }

    // Returns the index of the training pair whose key is closest to value. Because the keys are sorted
    // ascending, the distance first decreases and then increases, so the scan stops at the first increase.
    // On equal distances the later (larger-key) pair wins. Requires sorted to be non-empty.
    private static int FindNearestIndex(List<KeyValuePair<double, double>> sorted, double value) {
      int nearest = 0;
      double best = Math.Abs(value - sorted[0].Key);
      for (int i = 1; i < sorted.Count; i++) {
        double d = Math.Abs(value - sorted[i].Key);
        if (d > best) break;
        nearest = i;
        best = d;
      }
      return nearest;
    }

    // Increments the count stored for key, starting from zero for a key not seen before.
    private static void AddVote(Dictionary<double, int> counts, double key) {
      int current;
      counts.TryGetValue(key, out current);
      counts[key] = current + 1;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.