Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 14235

Last change on this file since 14235 was 14235, checked in by gkronber, 8 years ago

#2652: added scaling and optional specification of feature-weights for kNN

File size: 14.9 KB
RevLine 
[6583]1#region License Information
2/* HeuristicLab
[14185]3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[6583]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification
  /// </summary>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // Standard deviations at or below this value are treated as zero when weights are determined automatically.
    // Without this guard a constant feature would get an infinite (or absurdly large) weight.
    private const double ZeroStdDevTolerance = 1.0E-12;

    // The alglib query functions keep their per-query state and results inside the kd-tree object
    // (cf. the query fields copied in the cloning constructor). A query and the reading of its results
    // must therefore not be interleaved with another query on the same tree.
    private readonly object kdTreeLockObject = new object();

    private alglib.nearestneighbor.kdtree kdTree;
    /// <summary>
    /// The alglib kd-tree that stores the (scaled) training inputs and the target column.
    /// Assigning a different tree raises <see cref="Changed"/>; assigning null throws.
    /// </summary>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException();
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    [Storable]
    private double[] classValues;
    [Storable]
    private int k;
    [Storable(DefaultValue = null)]
    private double[] weights; // not set for old versions loaded from disk
    [Storable(DefaultValue = null)]
    private double[] offsets; // not set for old versions loaded from disk

    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // the kd-tree fields are filled in afterwards through the KDTree* storable properties
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      // deep copy of the kd-tree (including its query buffers)
      kdTree = new alglib.nearestneighbor.kdtree();
      kdTree.approxf = original.kdTree.approxf;
      kdTree.boxmax = (double[])original.kdTree.boxmax.Clone();
      kdTree.boxmin = (double[])original.kdTree.boxmin.Clone();
      kdTree.buf = (double[])original.kdTree.buf.Clone();
      kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone();
      kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone();
      kdTree.curdist = original.kdTree.curdist;
      kdTree.debugcounter = original.kdTree.debugcounter;
      kdTree.idx = (int[])original.kdTree.idx.Clone();
      kdTree.kcur = original.kdTree.kcur;
      kdTree.kneeded = original.kdTree.kneeded;
      kdTree.n = original.kdTree.n;
      kdTree.nodes = (int[])original.kdTree.nodes.Clone();
      kdTree.normtype = original.kdTree.normtype;
      kdTree.nx = original.kdTree.nx;
      kdTree.ny = original.kdTree.ny;
      kdTree.r = (double[])original.kdTree.r.Clone();
      kdTree.rneeded = original.kdTree.rneeded;
      kdTree.selfmatch = original.kdTree.selfmatch;
      kdTree.splits = (double[])original.kdTree.splits.Clone();
      kdTree.tags = (int[])original.kdTree.tags.Clone();
      kdTree.x = (double[])original.kdTree.x.Clone();
      kdTree.xy = (double[,])original.kdTree.xy.Clone();

      k = original.k;
      isCompatibilityLoaded = original.IsCompatibilityLoaded;
      // models loaded from old files have no scaling vectors
      if (!IsCompatibilityLoaded) {
        weights = (double[])original.weights.Clone();
        offsets = (double[])original.offsets.Clone();
      }
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given rows of <paramref name="dataset"/>.
    /// Each input variable is shifted by minus its mean over <paramref name="rows"/> and multiplied by its weight.
    /// </summary>
    /// <param name="dataset">The dataset containing the input and target variables.</param>
    /// <param name="rows">The training rows.</param>
    /// <param name="k">The number of neighbours queried per prediction.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The input variables (features).</param>
    /// <param name="weights">Optional weights, one per input variable. When null, each feature is scaled by
    /// 1 / (population standard deviation); features with (almost) zero standard deviation are left unscaled.</param>
    /// <param name="classValues">The class values for classification, or null for regression.</param>
    /// <exception cref="ArgumentException">The number of weights does not match the number of input variables.</exception>
    /// <exception cref="NotSupportedException">The (scaled) training data contains NaN or infinity values.</exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
      : base(targetVariable) {
      Name = ItemName;
      Description = ItemDescription;
      this.k = k;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      // the rows are enumerated once per variable below; materialize them so a lazy sequence is evaluated only once
      var rowIndices = rows.ToArray();

      // Newly created models always use scaling: IsCompatibilityLoaded is false here and only becomes true
      // for models deserialized from files written before offsets/weights existed.
      this.offsets = this.allowedInputVariables
        .Select(name => dataset.GetDoubleValues(name, rowIndices).Average() * -1)
        .Concat(new double[] { 0 }) // no offset for target variable
        .ToArray();
      if (weights == null) {
        // automatic determination of weights (all non-constant features should have variance = 1)
        this.weights = this.allowedInputVariables
          .Select(name => GetAutomaticWeight(dataset.GetDoubleValues(name, rowIndices)))
          .Concat(new double[] { 1.0 }) // no scaling for target variable
          .ToArray();
      } else {
        // user specified weights (+ 1 for target)
        this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
        if (this.weights.Length - 1 != this.allowedInputVariables.Length)
          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
      }
      double[,] inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rowIndices, this.offsets, this.weights);

      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException(
          "Nearest neighbour classification does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1; // the last column holds the target

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        int nClasses = classValues.Length;
        // map original class values to values [0..nClasses-1]
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < nClasses; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
        }
      }
      // nFeatures input columns, 1 target column, normtype 2 (L2 norm per the alglib documentation)
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, nFeatures, 1, 2, kdTree);
    }

    // Weight that scales a feature to unit population variance. Returns 1.0 (no scaling) when the standard
    // deviation is (almost) zero or NaN, so constant features do not produce infinite or NaN scaled values.
    private static double GetAutomaticWeight(IEnumerable<double> values) {
      double stdDev = values.StandardDeviationPop();
      return stdDev > ZeroStdDevTolerance ? 1.0 / stdDev : 1.0;
    }

    /// <summary>
    /// Creates a rows x variables matrix whose entry for variable j is (value + offsets[j]) * factors[j].
    /// </summary>
    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
      var x = new double[rows.Count(), variables.Count()];
      var colIdx = 0;
      foreach (var variableName in variables) {
        var rowIdx = 0;
        foreach (var val in dataset.GetDoubleValues(variableName, rows)) {
          x[rowIdx, colIdx] = (val + offsets[colIdx]) * factors[colIdx];
          rowIdx++;
        }
        colIdx++;
      }
      return x;
    }

    // Prepares the input matrix for predictions: unscaled for models loaded from old files (no offsets/weights),
    // otherwise scaled exactly like the training data.
    private double[,] PrepareInputData(IDataset dataset, IEnumerable<int> rows) {
      if (IsCompatibilityLoaded)
        return AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
      return CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    }

    // Runs a k-nearest-neighbour query for x and copies the distances and neighbour rows (inputs + target column)
    // into the given buffers; returns the number of neighbours found. The last argument of kdtreequeryknn is
    // alglib's selfmatch flag, passed as false. The whole sequence is locked because these calls mutate the tree.
    private int QueryNeighbours(double[] x, ref double[] dists, ref double[,] neighbours) {
      lock (kdTreeLockObject) {
        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); // gkronber: this call changes the kdTree data structure
        return actNeighbours;
      }
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Regression estimates: the inverse-distance-weighted mean of the target values of the k nearest neighbours.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = PrepareInputData(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int actNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        double distanceWeightedValue = 0.0;
        double distsSum = 0.0;
        for (int i = 0; i < actNeighbours; i++) {
          distanceWeightedValue += neighbours[i, columns] / dists[i];
          distsSum += 1.0 / dists[i];
        }
        yield return distanceWeightedValue / distsSum;
      }
    }

    /// <summary>
    /// Classification estimates: the class occurring most often among the k nearest neighbours
    /// (ties go to the class with the lower index).
    /// </summary>
    /// <exception cref="InvalidOperationException">The model has no class values (regression model).</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = PrepareInputData(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] y = new int[classValues.Length];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int actNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        // count votes; the target column holds class indices [0..nClasses-1] (mapped in the constructor)
        Array.Clear(y, 0, y.Length);
        for (int i = 0; i < actNeighbours; i++) {
          int classValue = (int)Math.Round(neighbours[i, columns]);
          y[classValue]++;
        }

        // find the class with the largest number of votes
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }


    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    #region events
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion


    // BackwardsCompatibility3.3
    #region Backwards compatible code, remove with 3.4

    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
    [Storable(DefaultValue = true)]
    public bool IsCompatibilityLoaded {
      get { return isCompatibilityLoaded; }
      set { isCompatibilityLoaded = value; }
    }
    #endregion
    #region persistence
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.