Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 12588

Last change on this file since 12588 was 12588, checked in by gkronber, 9 years ago

#2261: merged changes from trunk

File size: 12.1 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification.
  /// The training rows are stored in an alglib kd-tree; regression predictions are the
  /// inverse-distance weighted mean of the k nearest targets, classification predictions
  /// are the majority class among the k nearest neighbours.
  /// </summary>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel {

    // Queries are not self-contained: kdtreequeryknn writes its results into the kd-tree
    // object and they are read back by separate kdtreequeryresults* calls. If two threads
    // query the same (shared) model at once they can read each other's neighbours, so the
    // query/read-back sequence is serialized on this object. Not persisted; every
    // constructor (including cloning and deserialization) gets a fresh instance.
    private readonly object kdTreeLockObject = new object();

    private alglib.nearestneighbor.kdtree kdTree;
    /// <summary>
    /// The kd-tree holding the training data. Assigning a different tree raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">A different tree is assigned and it is null.</exception>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException();
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    [Storable]
    private string targetVariable;
    [Storable]
    private string[] allowedInputVariables;
    // Original class values; null for regression models. Index i corresponds to the
    // class index i stored in the kd-tree's target column.
    [Storable]
    private double[] classValues;
    // Number of neighbours requested per query.
    [Storable]
    private int k;

    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // The tree's fields are restored afterwards through the [Storable] KDTree* properties.
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      // deep copy of all kd-tree fields so that the clone does not share any buffers with the original
      kdTree = new alglib.nearestneighbor.kdtree();
      kdTree.approxf = original.kdTree.approxf;
      kdTree.boxmax = (double[])original.kdTree.boxmax.Clone();
      kdTree.boxmin = (double[])original.kdTree.boxmin.Clone();
      kdTree.buf = (double[])original.kdTree.buf.Clone();
      kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone();
      kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone();
      kdTree.curdist = original.kdTree.curdist;
      kdTree.debugcounter = original.kdTree.debugcounter;
      kdTree.idx = (int[])original.kdTree.idx.Clone();
      kdTree.kcur = original.kdTree.kcur;
      kdTree.kneeded = original.kdTree.kneeded;
      kdTree.n = original.kdTree.n;
      kdTree.nodes = (int[])original.kdTree.nodes.Clone();
      kdTree.normtype = original.kdTree.normtype;
      kdTree.nx = original.kdTree.nx;
      kdTree.ny = original.kdTree.ny;
      kdTree.r = (double[])original.kdTree.r.Clone();
      kdTree.rneeded = original.kdTree.rneeded;
      kdTree.selfmatch = original.kdTree.selfmatch;
      kdTree.splits = (double[])original.kdTree.splits.Clone();
      kdTree.tags = (int[])original.kdTree.tags.Clone();
      kdTree.x = (double[])original.kdTree.x.Clone();
      kdTree.xy = (double[,])original.kdTree.xy.Clone();

      k = original.k;
      targetVariable = original.targetVariable;
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="dataset">Dataset containing the input and target variables.</param>
    /// <param name="rows">Training rows.</param>
    /// <param name="k">Number of neighbours used per prediction.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Names of the input variables (enumerated once).</param>
    /// <param name="classValues">Class values for a classification model; null for regression.</param>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinity values.</exception>
    /// <exception cref="KeyNotFoundException">A target value is not contained in <paramref name="classValues"/>.</exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) {
      Name = ItemName;
      Description = ItemDescription;
      this.k = k;
      this.targetVariable = targetVariable;
      // materialize once; the caller's sequence is not enumerated a second time below
      this.allowedInputVariables = allowedInputVariables.ToArray();

      // the target is appended as the last column of the training matrix
      var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
                                   this.allowedInputVariables.Concat(new string[] { targetVariable }),
                                   rows);

      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException(
          "Nearest neighbour classification does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1;

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        int nClasses = classValues.Length;
        // map original class values to values [0..nClasses-1] so that predictions can
        // use the stored target directly as an index into this.classValues
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < nClasses; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          double classIndex;
          if (!classIndices.TryGetValue(inputMatrix[row, nFeatures], out classIndex))
            throw new KeyNotFoundException(string.Format(
              "The target value {0} of variable {1} is not contained in the given class values.",
              inputMatrix[row, nFeatures], targetVariable));
          inputMatrix[row, nFeatures] = classIndex;
        }
      }
      // nFeatures input columns, 1 target column; normtype 2 is presumably the Euclidean
      // norm (see alglib kdtreebuild documentation)
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, nFeatures, 1, 2, kdTree);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Queries the k nearest neighbours of <paramref name="x"/> and copies the results into the
    /// caller-owned buffers. Each row of <paramref name="neighbours"/> holds the neighbour's input
    /// values followed by its target (or class index) in the last column.
    /// </summary>
    /// <returns>The number of neighbours actually found; may be smaller than k.</returns>
    private int QueryNeighbours(double[] x, ref double[] dists, ref double[,] neighbours) {
      lock (kdTreeLockObject) {
        // read the field once so that all three calls work on the same tree
        var tree = kdTree;
        // selfmatch = false: per the alglib documentation only points with non-zero distance
        // are returned, so dists[i] > 0 for every returned neighbour
        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(tree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsdistances(tree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(tree, ref neighbours);
        return actNeighbours;
      }
    }

    /// <summary>
    /// Returns, for each row, the inverse-distance weighted mean of the targets of the k nearest
    /// training rows. Evaluated lazily. Yields NaN for a row for which no neighbour is found.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int actNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        // weights are 1 / distance; the target is the last column of each neighbour row
        double distanceWeightedValue = 0.0;
        double distsSum = 0.0;
        for (int i = 0; i < actNeighbours; i++) {
          distanceWeightedValue += neighbours[i, columns] / dists[i];
          distsSum += 1.0 / dists[i];
        }
        // NOTE: with actNeighbours == 0 this is 0 / 0 = NaN
        yield return distanceWeightedValue / distsSum;
      }
    }

    /// <summary>
    /// Returns, for each row, the class value that occurs most often among the k nearest training
    /// rows (ties go to the class listed first in the class values). Evaluated lazily.
    /// </summary>
    /// <exception cref="InvalidOperationException">The model was not built with class values.</exception>
    public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] classVotes = new int[classValues.Length];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int actNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        // the last column holds the class index assigned in the constructor
        Array.Clear(classVotes, 0, classVotes.Length);
        for (int i = 0; i < actNeighbours; i++) {
          int classIndex = (int)Math.Round(neighbours[i, columns]);
          classVotes[classIndex]++;
        }

        // find the class with the most votes (strict comparison keeps the first one on ties)
        int maxProbClassIndex = 0;
        double maxProb = classVotes[0];
        for (int i = 1; i < classVotes.Length; i++) {
          if (maxProb < classVotes[i]) {
            maxProb = classVotes[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    public INearestNeighbourRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(new RegressionProblemData(problemData), this);
    }
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }
    public INearestNeighbourClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(new ClassificationProblemData(problemData), this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }

    #region events
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion

    #region persistence
    // The alglib kd-tree is not storable itself; each of its fields is persisted through
    // one of the following pass-through properties.
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.