Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2942_KNNRegressionClassification/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 16408

Last change on this file since 16408 was 16408, checked in by msemenki, 5 years ago

#2942: Add the ability for KNN regression/classification to utilize data points with zero distance to the query point. Change the way weights are assigned to neighboring points (to avoid division by zero).

File size: 16.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification
  /// </summary>
  /// <remarks>
  /// The (offset/weight-scaled) training points are stored in an alglib kd-tree built with
  /// normtype 2 (Euclidean in alglib). Regression predictions are the inverse-distance weighted
  /// mean of the k nearest targets; classification predictions are a majority vote among the
  /// k nearest neighbours.
  /// </remarks>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // alglib kd-tree queries store their results inside the tree object itself, so a query and
    // the retrieval of its results must be performed atomically (see QueryNeighbours).
    private readonly object kdTreeLockObject = new object();

    private alglib.nearestneighbor.kdtree kdTree;
    /// <summary>
    /// The kd-tree holding the training points. Assigning a different tree raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">A null value is assigned while the current tree is not null.</exception>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException();
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    [Storable]
    private double[] classValues;
    [Storable]
    private int k;
    [Storable]
    private bool selfMatch;
    [Storable(DefaultValue = null)]
    private double[] weights; // not set for old versions loaded from disk
    [Storable(DefaultValue = null)]
    private double[] offsets; // not set for old versions loaded from disk

    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // The [Storable] KDTree* property setters write directly into kdTree,
      // so the tree instance has to exist before they are restored.
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }

    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      kdTree = CloneKDTree(original.kdTree);
      selfMatch = original.selfMatch;
      k = original.k;
      isCompatibilityLoaded = original.IsCompatibilityLoaded;
      if (!IsCompatibilityLoaded) {
        // weights/offsets only exist for models that were not loaded in compatibility mode
        weights = (double[])original.weights.Clone();
        offsets = (double[])original.offsets.Clone();
      }
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="dataset">Dataset providing the input variables and the target variable.</param>
    /// <param name="rows">Training rows.</param>
    /// <param name="k">Number of neighbours queried per prediction.</param>
    /// <param name="selfMatch">Passed to alglib's kdtreequeryknn; controls whether training points at distance zero from the query point may be returned as neighbours.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Names of the input variables.</param>
    /// <param name="weights">
    /// Optional per-input scaling factors. When null, each input is scaled by 1 / (population standard deviation),
    /// or by 1 if that standard deviation is (almost) zero.
    /// </param>
    /// <param name="classValues">Distinct class values for classification models; null for regression models.</param>
    /// <exception cref="ArgumentException">The number of weights does not match the number of input variables.</exception>
    /// <exception cref="NotSupportedException">The scaled training data contain NaN or infinity.</exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, bool selfMatch, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
      : base(targetVariable) {
      Name = ItemName;
      Description = ItemDescription;
      this.selfMatch = selfMatch;
      this.k = k;
      this.allowedInputVariables = allowedInputVariables.ToArray();

      // Materialize once: the rows are enumerated once per variable below.
      var trainingRows = rows.ToArray();
      var variablesWithTarget = this.allowedInputVariables.Concat(new string[] { targetVariable }).ToArray();

      // A freshly constructed model is never in compatibility mode, so the data is always scaled.
      this.offsets = this.allowedInputVariables
        .Select(name => dataset.GetDoubleValues(name, trainingRows).Average() * -1)
        .Concat(new double[] { 0 }) // no offset for target variable
        .ToArray();
      if (weights == null) {
        // automatic determination of weights (all features should have variance = 1)
        this.weights = this.allowedInputVariables
          .Select(name => {
            var pop = dataset.GetDoubleValues(name, trainingRows).StandardDeviationPop();
            return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
          })
          .Concat(new double[] { 1.0 }) // no scaling for target variable
          .ToArray();
      } else {
        // user specified weights (+ 1 for target)
        this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
        if (this.weights.Length - 1 != this.allowedInputVariables.Length)
          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
      }
      double[,] inputMatrix = CreateScaledData(dataset, variablesWithTarget, trainingRows, this.offsets, this.weights);

      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException(
          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1; // the last column holds the target

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        // map original class values to values [0..nClasses-1] so they can index the vote array
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < classValues.Length; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
        }
      }
      // nx = nFeatures input columns, ny = 1 target column, normtype = 2
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, nFeatures, 1, 2, kdTree);
    }

    // Creates a deep copy of an alglib kd-tree; every array member is cloned.
    private static alglib.nearestneighbor.kdtree CloneKDTree(alglib.nearestneighbor.kdtree source) {
      var copy = new alglib.nearestneighbor.kdtree();
      copy.approxf = source.approxf;
      copy.boxmax = (double[])source.boxmax.Clone();
      copy.boxmin = (double[])source.boxmin.Clone();
      copy.buf = (double[])source.buf.Clone();
      copy.curboxmax = (double[])source.curboxmax.Clone();
      copy.curboxmin = (double[])source.curboxmin.Clone();
      copy.curdist = source.curdist;
      copy.debugcounter = source.debugcounter;
      copy.idx = (int[])source.idx.Clone();
      copy.kcur = source.kcur;
      copy.kneeded = source.kneeded;
      copy.n = source.n;
      copy.nodes = (int[])source.nodes.Clone();
      copy.normtype = source.normtype;
      copy.nx = source.nx;
      copy.ny = source.ny;
      copy.r = (double[])source.r.Clone();
      copy.rneeded = source.rneeded;
      copy.selfmatch = source.selfmatch;
      copy.splits = (double[])source.splits.Clone();
      copy.tags = (int[])source.tags.Clone();
      copy.x = (double[])source.x.Clone();
      copy.xy = (double[,])source.xy.Clone();
      return copy;
    }

    // Reads the given variables for the given rows, applying x -> (x + offset) * factor per column.
    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
      var transforms =
        variables.Select(
          (_, colIdx) =>
            new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
      return dataset.ToArray(variables, transforms, rows);
    }

    // Builds the query matrix (inputs only) the same way the training data was built:
    // unscaled for compatibility-loaded models, scaled with the stored offsets/weights otherwise.
    private double[,] GetInputMatrix(IDataset dataset, IEnumerable<int> rows) {
      if (IsCompatibilityLoaded)
        return dataset.ToArray(allowedInputVariables, rows);
      return CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    }

    // Runs a k-NN query for x and copies distances and neighbour rows (inputs + target) into the
    // given buffers (alglib may reallocate them, hence ref). Returns the number of neighbours found.
    private int QueryNeighbours(double[] x, ref double[] dists, ref double[,] neighbours) {
      lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure
        int numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch);
        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
        return numNeighbours;
      }
    }

    // Replaces zero distances among the first numNeighbours entries with 1/100 of the smallest
    // non-zero distance, so that exact matches dominate the inverse-distance weighting without
    // dividing by zero. If every neighbour is at distance zero, 1.0 / 100 is used; all weights
    // are then equal and the weighted mean reduces to the plain mean.
    private static void ReplaceZeroDistances(double[] dists, int numNeighbours) {
      bool hasZero = false;
      double minNonZero = double.PositiveInfinity;
      for (int i = 0; i < numNeighbours; i++) {
        if (dists[i] == 0) hasZero = true;
        else if (dists[i] < minNonZero) minNonZero = dists[i];
      }
      if (!hasZero) return;

      double surrogate = (double.IsPositiveInfinity(minNonZero) ? 1.0 : minNonZero) / 100.0;
      for (int i = 0; i < numNeighbours; i++) {
        if (dists[i] == 0) dists[i] = surrogate;
      }
    }

    // Inverse-distance weighted mean of the target column of the first numNeighbours neighbours.
    // NOTE(review): yields NaN (0/0) if no neighbour was found.
    private static double InverseDistanceWeightedMean(double[] dists, double[,] neighbours, int numNeighbours, int targetColumn) {
      double distanceWeightedValue = 0.0;
      double distsSum = 0.0;
      for (int i = 0; i < numNeighbours; i++) {
        distanceWeightedValue += neighbours[i, targetColumn] / dists[i];
        distsSum += 1.0 / dists[i];
      }
      return distanceWeightedValue / distsSum;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Predicts, for each row, the inverse-distance weighted mean of the targets of the k nearest training points.
    /// When selfMatch is enabled, neighbours at distance zero are weighted with a surrogate distance
    /// of 1/100 of the smallest non-zero neighbour distance.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = GetInputMatrix(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours = QueryNeighbours(x, ref dists, ref neighbours);
        if (selfMatch)
          ReplaceZeroDistances(dists, numNeighbours);
        // the target value is stored in the column right after the inputs
        yield return InverseDistanceWeightedMean(dists, neighbours, numNeighbours, columns);
      }
    }

    /// <summary>
    /// Predicts, for each row, the class that occurs most often among the k nearest training points
    /// (ties are resolved in favour of the class listed first in the class values).
    /// </summary>
    /// <exception cref="InvalidOperationException">The model has no class values (enumeration is deferred, so this is thrown on first iteration).</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = GetInputMatrix(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] votes = new int[classValues.Length];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        // the target column holds class indices (mapped in the constructor)
        Array.Clear(votes, 0, votes.Length);
        for (int i = 0; i < numNeighbours; i++) {
          int classIndex = (int)Math.Round(neighbours[i, columns]);
          votes[classIndex]++;
        }

        // find the class with the largest number of votes (strict '>' keeps the first on ties)
        int bestClassIndex = 0;
        for (int i = 1; i < votes.Length; i++) {
          if (votes[i] > votes[bestClassIndex]) bestClassIndex = i;
        }
        yield return classValues[bestClassIndex];
      }
    }

    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not a regression nor a classification problem data. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    #region events
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion

    // BackwardsCompatibility3.3
    #region Backwards compatible code, remove with 3.4

    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
    [Storable(DefaultValue = true)]
    public bool IsCompatibilityLoaded {
      get { return isCompatibilityLoaded; }
      set { isCompatibilityLoaded = value; }
    }
    #endregion

    #region persistence
    // Each property below exposes one field of the alglib kd-tree so that it is persisted;
    // the setters write straight into kdTree (created in the storable constructor).
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.