Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 6588

Last change on this file since 6588 was 6584, checked in by gkronber, 13 years ago

#763 implemented persistence for k nearest neighbour models.

File size: 10.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.IO;
25using System.Linq;
26using System.Text;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31using SVM;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification.
  /// </summary>
  /// <remarks>
  /// The model wraps an ALGLIB kd-tree together with the number of neighbours
  /// <c>k</c>, the name of the target variable, the input variables used to build
  /// query points and (for classification) the class values.
  /// The kd-tree itself is not marked as storable; it is persisted field by field
  /// through the storable properties in the "persistence" region.
  /// </remarks>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel {

    private alglib.nearestneighbor.kdtree kdTree;
    /// <summary>
    /// Gets or sets the kd-tree that is queried for nearest neighbours.
    /// Assigning a different instance raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned value is <c>null</c>.</exception>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException("value");
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    [Storable]
    private string targetVariable;
    [Storable]
    private string[] allowedInputVariables;
    // Only set for classification models; remains null otherwise.
    [Storable]
    private double[] classValues;
    // Number of neighbours requested from the kd-tree for each query point.
    [Storable]
    private int k;

    /// <summary>
    /// Constructor used by the persistence framework.
    /// </summary>
    /// <param name="deserializing">True while an instance is restored from storage.</param>
    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // The storable KDTree* property setters write directly into the kd-tree,
      // so an empty instance must exist before they are invoked.
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }

    /// <summary>
    /// Cloning constructor. Copies every kd-tree field of <paramref name="original"/>;
    /// all arrays are cloned so that the copy does not share buffers with the original.
    /// </summary>
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      alglib.nearestneighbor.kdtree source = original.kdTree;

      kdTree = new alglib.nearestneighbor.kdtree();
      kdTree.approxf = source.approxf;
      kdTree.boxmax = (double[])source.boxmax.Clone();
      kdTree.boxmin = (double[])source.boxmin.Clone();
      kdTree.buf = (double[])source.buf.Clone();
      kdTree.curboxmax = (double[])source.curboxmax.Clone();
      kdTree.curboxmin = (double[])source.curboxmin.Clone();
      kdTree.curdist = source.curdist;
      kdTree.debugcounter = source.debugcounter;
      kdTree.distmatrixtype = source.distmatrixtype;
      kdTree.idx = (int[])source.idx.Clone();
      kdTree.kcur = source.kcur;
      kdTree.kneeded = source.kneeded;
      kdTree.n = source.n;
      kdTree.nodes = (int[])source.nodes.Clone();
      kdTree.normtype = source.normtype;
      kdTree.nx = source.nx;
      kdTree.ny = source.ny;
      kdTree.r = (double[])source.r.Clone();
      kdTree.rneeded = source.rneeded;
      kdTree.selfmatch = source.selfmatch;
      kdTree.splits = (double[])source.splits.Clone();
      kdTree.tags = (int[])source.tags.Clone();
      kdTree.x = (double[])source.x.Clone();
      kdTree.xy = (double[,])source.xy.Clone();

      k = original.k;
      targetVariable = original.targetVariable;
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Creates a new nearest neighbour model.
    /// </summary>
    /// <param name="kdTree">The kd-tree that is queried for neighbours; stored by reference.</param>
    /// <param name="k">The number of neighbours requested per query point.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The input variables used to build query points; copied into an array.</param>
    /// <param name="classValues">The class values for classification models (copied), or <c>null</c>.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="kdTree"/> or <paramref name="allowedInputVariables"/> is <c>null</c>.
    /// </exception>
    public NearestNeighbourModel(alglib.nearestneighbor.kdtree kdTree, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
      : base() {
      // Fail early instead of with a NullReferenceException on the first query;
      // this mirrors the null check in the KDTree setter.
      if (kdTree == null) throw new ArgumentNullException("kdTree");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");

      this.name = ItemName;
      this.description = ItemDescription;
      this.kdTree = kdTree;
      this.k = k;
      this.targetVariable = targetVariable;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (classValues != null)
        this.classValues = (double[])classValues.Clone();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Lazily calculates an estimate for each of the given rows as the
    /// inverse-distance weighted mean of the target values of the nearest neighbours.
    /// </summary>
    /// <param name="dataset">The dataset that provides the input values.</param>
    /// <param name="rows">The rows for which estimates are produced.</param>
    /// <returns>
    /// One value per row. The value is NaN when the kd-tree query yields no neighbours
    /// for a row (0 / 0).
    /// </returns>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      // One row per neighbour; the value at index 'columns' is read as the target value.
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // NOTE(review): the last argument is ALGLIB's selfmatch flag; with 'false' points at
        // distance zero are presumably not reported - confirm against the ALGLIB documentation.
        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);

        double distanceWeightedValue = 0.0;
        double weightsSum = 0.0;
        double coincidingSum = 0.0;
        int coincidingCount = 0;
        for (int i = 0; i < actNeighbours; i++) {
          double target = neighbours[i, columns];
          double weight = 1.0 / dists[i];
          if (double.IsInfinity(weight)) {
            // A zero (or subnormal) distance makes the weight overflow; Infinity / Infinity
            // would turn the whole estimate into NaN, so such neighbours are averaged separately.
            coincidingSum += target;
            coincidingCount++;
          } else {
            distanceWeightedValue += target / dists[i];
            weightsSum += weight;
          }
        }

        if (coincidingCount > 0)
          yield return coincidingSum / coincidingCount;
        else
          yield return distanceWeightedValue / weightsSum;
      }
    }

    /// <summary>
    /// Lazily calculates the estimated class value for each of the given rows by a
    /// majority vote among the nearest neighbours. Ties are resolved in favour of the
    /// class with the smallest index.
    /// </summary>
    /// <param name="dataset">The dataset that provides the input values.</param>
    /// <param name="rows">The rows for which class values are produced.</param>
    /// <returns>One element of the class values array per row.</returns>
    /// <exception cref="InvalidOperationException">
    /// The model has no class values (thrown when enumeration starts).
    /// </exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      if (classValues == null)
        throw new InvalidOperationException("The nearest neighbour model was created without class values and cannot be used for classification.");

      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] votes = new int[classValues.Length];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // The vote is unweighted, so neighbour distances are not retrieved here.
        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);

        Array.Clear(votes, 0, votes.Length);
        for (int i = 0; i < actNeighbours; i++) {
          // NOTE(review): the stored target is used as an index into classValues, i.e. the
          // kd-tree is assumed to hold class indices rather than raw class values - confirm
          // against the code that builds the tree.
          int classIndex = (int)Math.Round(neighbours[i, columns]);
          votes[classIndex]++;
        }

        // find the class with the largest number of votes
        int maxVotesClassIndex = 0;
        int maxVotes = votes[0];
        for (int i = 1; i < votes.Length; i++) {
          if (maxVotes < votes[i]) {
            maxVotes = votes[i];
            maxVotesClassIndex = i;
          }
        }
        yield return classValues[maxVotesClassIndex];
      }
    }

    #region events
    /// <summary>
    /// Raised when a different kd-tree instance is assigned to <see cref="KDTree"/>.
    /// </summary>
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      // Copy to a local so the null check and the invocation see the same delegate.
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion

    #region persistence
    // The properties below expose the individual fields of the kd-tree to the persistence
    // framework. Getters and setters pass the underlying values and arrays through without
    // copying. The property names must stay unchanged as they are storable members.
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int KDTreeDistMatrixType {
      get { return kdTree.distmatrixtype; }
      set { kdTree.distmatrixtype = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.