source: branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 14237

Last change on this file since 14237 was 14237, checked in by gkronber, 3 years ago

#2650: work in progress..

File size: 12.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification.
  /// </summary>
  /// <remarks>
  /// The training data is held in an ALGLIB kd-tree. Regression estimates are the
  /// inverse-distance-weighted average of the target values of the k nearest neighbours;
  /// classification estimates are the majority class among the k nearest neighbours.
  /// NOTE(review): the ALGLIB query functions are handed the shared <c>kdTree</c> instance and the
  /// results are read back from it, so concurrent predictions on one model instance are
  /// presumably not safe - confirm before using the model from several threads.
  /// </remarks>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // Not marked [Storable] itself; its fields are persisted through the storable
    // KDTree* properties in the persistence region at the end of the class.
    private alglib.nearestneighbor.kdtree kdTree;

    /// <summary>
    /// Gets or sets the kd-tree that holds the training data.
    /// Assigning a different instance raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Thrown when the assigned value is null.</exception>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException("value");
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    /// <summary>
    /// Gets the names of the input variables the model reads when predicting.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    // Original class values; null when the model was built without class values (regression).
    [Storable]
    private double[] classValues;
    // Number of neighbours requested per query.
    [Storable]
    private int k;

    /// <summary>
    /// Constructor used by the persistence framework.
    /// </summary>
    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // The storable KDTree* properties write into kdTree, so an (empty) instance has to exist.
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }

    /// <summary>
    /// Cloning constructor; copies the kd-tree and all model fields so that the clone does not
    /// share arrays with <paramref name="original"/>.
    /// </summary>
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      kdTree = CloneKdTree(original.kdTree);

      k = original.k;
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given rows of a dataset.
    /// </summary>
    /// <param name="dataset">The dataset that provides the training data.</param>
    /// <param name="rows">The rows of <paramref name="dataset"/> used for training.</param>
    /// <param name="k">The number of neighbours requested for each prediction.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The names of the input variables; all must be of type double.</param>
    /// <param name="classValues">
    /// The class values for classification; when given, target values are stored in the tree as
    /// the index of the class value within this array. Leave null for regression.
    /// </param>
    /// <exception cref="NotSupportedException">
    /// Thrown when an input variable is not real-valued or when the training data contains NaN or
    /// infinity values.
    /// </exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
      : base(targetVariable) {
      Name = ItemName;
      Description = ItemDescription;
      this.k = k;
      // Materialize once; the parameter is a (possibly lazy) enumerable and is needed several times.
      this.allowedInputVariables = allowedInputVariables.ToArray();

      // check input variables. Only double variables are allowed.
      var invalidInputs = this.allowedInputVariables
        .Where(name => !dataset.VariableHasType<double>(name))
        .ToArray();
      if (invalidInputs.Length > 0)
        throw new NotSupportedException("Nearest neighbour models only support real-valued variables. Unsupported inputs: " + string.Join(", ", invalidInputs));

      // The target variable is appended as the last column of the matrix.
      var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
                                   this.allowedInputVariables.Concat(new string[] { targetVariable }),
                                   rows);

      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException(
          "Nearest neighbour classification does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1;

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        int nClasses = classValues.Length;
        // map original class values to values [0..nClasses-1]
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < nClasses; i++)
          classIndices[classValues[i]] = i;

        // The indexer throws a KeyNotFoundException for a target value that is not in classValues.
        for (int row = 0; row < nRows; row++) {
          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
        }
      }
      // nFeatures input columns, one target column; the 2 is ALGLIB's norm type argument
      // (2-norm per the ALGLIB kdtreebuild documentation).
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, nFeatures, 1, 2, kdTree);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Estimates the target value for each of the given rows as the inverse-distance-weighted
    /// average of the target values of the nearest neighbours.
    /// </summary>
    /// <param name="dataset">The dataset that provides the input values.</param>
    /// <param name="rows">The rows of <paramref name="dataset"/> to predict.</param>
    /// <returns>
    /// One estimate per row, produced lazily. The estimate is NaN when the query returns no
    /// neighbour (0.0 / 0.0).
    /// </returns>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      // One extra column: the target value follows the input values.
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // selfMatch = false: per the ALGLIB documentation only points with non-zero distance
        // are returned, so the divisions by dists[i] below do not divide by zero.
        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);

        double distanceWeightedValue = 0.0;
        double distsSum = 0.0;
        for (int i = 0; i < actNeighbours; i++) {
          distanceWeightedValue += neighbours[i, columns] / dists[i];
          distsSum += 1.0 / dists[i];
        }
        yield return distanceWeightedValue / distsSum;
      }
    }

    /// <summary>
    /// Estimates the class value for each of the given rows as the class that occurs most often
    /// among the nearest neighbours. On a tie the class with the lowest index wins.
    /// </summary>
    /// <param name="dataset">The dataset that provides the input values.</param>
    /// <param name="rows">The rows of <paramref name="dataset"/> to predict.</param>
    /// <returns>One class value per row, produced lazily.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown (on first enumeration) when the model has no class values.
    /// </exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] votes = new int[classValues.Length];
      // One extra column: the class index follows the input values.
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // Only the neighbours' class indices are needed for the vote; distances are not queried.
        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);

        Array.Clear(votes, 0, votes.Length);
        for (int i = 0; i < actNeighbours; i++) {
          // The tree stores the class index (see constructor), not the original class value.
          int classIndex = (int)Math.Round(neighbours[i, columns]);
          votes[classIndex]++;
        }

        // find the class with the largest number of votes
        int maxVotesClassIndex = 0;
        int maxVotes = votes[0];
        for (int i = 1; i < votes.Length; i++) {
          if (maxVotes < votes[i]) {
            maxVotes = votes[i];
            maxVotesClassIndex = i;
          }
        }
        yield return classValues[maxVotesClassIndex];
      }
    }


    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>
    /// Creates a copy of a kd-tree in which every array is duplicated, so that the copy and
    /// <paramref name="source"/> do not share mutable state.
    /// </summary>
    private static alglib.nearestneighbor.kdtree CloneKdTree(alglib.nearestneighbor.kdtree source) {
      var copy = new alglib.nearestneighbor.kdtree();
      copy.approxf = source.approxf;
      copy.boxmax = (double[])source.boxmax.Clone();
      copy.boxmin = (double[])source.boxmin.Clone();
      copy.buf = (double[])source.buf.Clone();
      copy.curboxmax = (double[])source.curboxmax.Clone();
      copy.curboxmin = (double[])source.curboxmin.Clone();
      copy.curdist = source.curdist;
      copy.debugcounter = source.debugcounter;
      copy.idx = (int[])source.idx.Clone();
      copy.kcur = source.kcur;
      copy.kneeded = source.kneeded;
      copy.n = source.n;
      copy.nodes = (int[])source.nodes.Clone();
      copy.normtype = source.normtype;
      copy.nx = source.nx;
      copy.ny = source.ny;
      copy.r = (double[])source.r.Clone();
      copy.rneeded = source.rneeded;
      copy.selfmatch = source.selfmatch;
      copy.splits = (double[])source.splits.Clone();
      copy.tags = (int[])source.tags.Clone();
      copy.x = (double[])source.x.Clone();
      copy.xy = (double[,])source.xy.Clone();
      return copy;
    }

    #region events
    /// <summary>
    /// Raised when a different kd-tree is assigned through <see cref="KDTree"/>.
    /// </summary>
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      // Copy to a local so the null check and the invocation see the same delegate.
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion

    #region persistence
    // Each field of the ALGLIB kd-tree is exposed as a storable property that reads from and
    // writes to the kdTree instance, so the tree is persisted together with the model.
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.