source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 17097

Last change on this file since 17097 was 17097, checked in by mkommend, 2 months ago

#2520: Merged 16565 - 16579 into stable.

File size: 16.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HEAL.Attic;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification.
  /// </summary>
  /// <remarks>
  /// The training inputs, with the target value appended as the last column, are stored in an
  /// alglib kd-tree. Regression estimates are inverse-distance weighted averages of the target
  /// values of the (up to) <c>k</c> nearest training points. Classification estimates are the
  /// class occurring most often among those neighbours; ties go to the class that comes first in
  /// the class values. Queries pass <c>selfMatch = false</c> to alglib, which according to the
  /// alglib documentation excludes training points at distance zero from the result.
  /// </remarks>
  [StorableType("A76C0823-3077-4ACE-8A40-E9B717C7DB60")]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // Serializes kd-tree queries: the alglib query functions store their results inside the
    // kdtree object, so concurrent queries on the same tree would overwrite each other.
    private readonly object kdTreeLockObject = new object();
    private alglib.nearestneighbor.kdtree kdTree;

    /// <summary>
    /// Gets or sets the alglib kd-tree holding the training data.
    /// Assigning a different, non-null tree raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">A different tree is assigned as <c>null</c>.</exception>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException("value");
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }


    /// <summary>The input variables the model was trained on (target excluded).</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    [Storable]
    private double[] classValues; // null for regression models
    [Storable]
    private int k;
    [Storable(DefaultValue = null)]
    private double[] weights; // not set for old versions loaded from disk
    [Storable(DefaultValue = null)]
    private double[] offsets; // not set for old versions loaded from disk

    [StorableConstructor]
    private NearestNeighbourModel(StorableConstructorFlag _) : base(_) {
      // an empty tree that the KDTree* persistence properties fill in during deserialization
      kdTree = new alglib.nearestneighbor.kdtree();
    }

    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      kdTree = CloneKdTree(original.kdTree);

      k = original.k;
      isCompatibilityLoaded = original.IsCompatibilityLoaded;
      // scaling vectors exist only for models that are not compatibility-loaded
      if (!IsCompatibilityLoaded) {
        weights = (double[])original.weights.Clone();
        offsets = (double[])original.offsets.Clone();
      }
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    // Creates a deep copy of an alglib kd-tree (all scalar fields copied, all arrays cloned).
    private static alglib.nearestneighbor.kdtree CloneKdTree(alglib.nearestneighbor.kdtree original) {
      var copy = new alglib.nearestneighbor.kdtree();
      copy.approxf = original.approxf;
      copy.boxmax = (double[])original.boxmax.Clone();
      copy.boxmin = (double[])original.boxmin.Clone();
      copy.buf = (double[])original.buf.Clone();
      copy.curboxmax = (double[])original.curboxmax.Clone();
      copy.curboxmin = (double[])original.curboxmin.Clone();
      copy.curdist = original.curdist;
      copy.debugcounter = original.debugcounter;
      copy.idx = (int[])original.idx.Clone();
      copy.kcur = original.kcur;
      copy.kneeded = original.kneeded;
      copy.n = original.n;
      copy.nodes = (int[])original.nodes.Clone();
      copy.normtype = original.normtype;
      copy.nx = original.nx;
      copy.ny = original.ny;
      copy.r = (double[])original.r.Clone();
      copy.rneeded = original.rneeded;
      copy.selfmatch = original.selfmatch;
      copy.splits = (double[])original.splits.Clone();
      copy.tags = (int[])original.tags.Clone();
      copy.x = (double[])original.x.Clone();
      copy.xy = (double[,])original.xy.Clone();
      return copy;
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="dataset">The data to train on.</param>
    /// <param name="rows">The training row indices.</param>
    /// <param name="k">The number of neighbours used for prediction; must be at least 1.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The input variables.</param>
    /// <param name="weights">
    /// Optional per-input scaling factors (one per input variable). When <c>null</c>, each input is
    /// scaled by 1 / (population standard deviation), or by 1 when that deviation is almost zero.
    /// </param>
    /// <param name="classValues">
    /// The class values for a classification model, or <c>null</c> for a regression model.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="dataset"/>, <paramref name="rows"/> or <paramref name="allowedInputVariables"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is smaller than 1.</exception>
    /// <exception cref="ArgumentException">The number of weights does not match the number of input variables.</exception>
    /// <exception cref="NotSupportedException">The (scaled) training data contains NaN or infinity.</exception>
    /// <exception cref="KeyNotFoundException">A target value is not contained in <paramref name="classValues"/>.</exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
      : base(targetVariable) {
      if (dataset == null) throw new ArgumentNullException("dataset");
      if (rows == null) throw new ArgumentNullException("rows");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");
      // a model with k < 1 could never predict (the neighbour buffers are sized by k)
      if (k < 1) throw new ArgumentOutOfRangeException("k", "The number of neighbours k must be at least 1.");

      Name = ItemName;
      Description = ItemDescription;
      this.k = k;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      // rows is read once per input variable and once more for the data matrix;
      // materialize it so a lazy sequence is evaluated only once
      var trainingRows = rows.ToArray();
      var modelVariables = this.allowedInputVariables.Concat(new string[] { targetVariable }).ToArray();

      // New models are never compatibility-loaded (isCompatibilityLoaded is initialized to false),
      // so the training data is always centered and scaled here.
      this.offsets = this.allowedInputVariables
        .Select(name => dataset.GetDoubleValues(name, trainingRows).Average() * -1)
        .Concat(new double[] { 0 }) // no offset for target variable
        .ToArray();
      if (weights == null) {
        // automatic determination of weights (all features should have variance = 1)
        this.weights = this.allowedInputVariables
          .Select(name => {
            var pop = dataset.GetDoubleValues(name, trainingRows).StandardDeviationPop();
            return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
          })
          .Concat(new double[] { 1.0 }) // no scaling for target variable
          .ToArray();
      } else {
        // user specified weights (+ 1 for target)
        this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
        if (this.weights.Length - 1 != this.allowedInputVariables.Length)
          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
      }
      double[,] inputMatrix = CreateScaledData(dataset, modelVariables, trainingRows, this.offsets, this.weights);

      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException(
          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1; // the last column holds the target

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        int nClasses = classValues.Length;
        // map original class values to values [0..nClasses-1]
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < nClasses; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          double targetValue = inputMatrix[row, nFeatures];
          double classIndex;
          if (!classIndices.TryGetValue(targetValue, out classIndex))
            throw new KeyNotFoundException("The target value " + targetValue + " is not contained in the provided class values.");
          inputMatrix[row, nFeatures] = classIndex;
        }
      }
      // ny = 1: one output column (the target). Per the alglib documentation, normtype 2 is the Euclidean norm.
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, nFeatures, 1, 2, kdTree);
    }

    // Returns the given variables/rows as a matrix with column j transformed by offsets[j] and
    // factors[j]. Assuming LinearTransformation computes x * Multiplier + Addend, each value becomes
    // (x + offsets[j]) * factors[j].
    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
      var transforms =
        variables.Select(
          (_, colIdx) =>
            new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
      return dataset.ToArray(variables, transforms, rows);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    // Returns the prediction inputs in the representation used for the kd-tree:
    // unscaled for compatibility-loaded models, centered and scaled otherwise.
    private double[,] GetInputData(IDataset dataset, IEnumerable<int> rows) {
      if (IsCompatibilityLoaded)
        return dataset.ToArray(allowedInputVariables, rows);
      return CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    }

    // Queries the k nearest neighbours of x. Their distances are written to dists and their rows
    // (inputs followed by the target) to neighbours; alglib may reallocate both arrays.
    // Returns the number of neighbours found.
    private int QueryNeighbours(double[] x, ref double[] dists, ref double[,] neighbours) {
      lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure
        // read the field once so all three calls use the same tree
        var tree = kdTree;
        int numNeighbours = alglib.nearestneighbor.kdtreequeryknn(tree, x, k, false);
        alglib.nearestneighbor.kdtreequeryresultsdistances(tree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(tree, ref neighbours);
        return numNeighbours;
      }
    }

    /// <summary>
    /// Estimates the target value of each row as the inverse-distance weighted average of the
    /// target values of its (up to) k nearest training points.
    /// </summary>
    /// <remarks>
    /// Evaluated lazily. If no neighbour is found for a row, the estimate is NaN (0 / 0).
    /// </remarks>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = GetInputData(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        double distanceWeightedValue = 0.0;
        double distsSum = 0.0;
        for (int i = 0; i < numNeighbours; i++) {
          // the neighbour's target value is in the column after the inputs
          distanceWeightedValue += neighbours[i, columns] / dists[i];
          distsSum += 1.0 / dists[i];
        }
        yield return distanceWeightedValue / distsSum;
      }
    }

    /// <summary>
    /// Estimates the class of each row as the class occurring most often among its (up to) k
    /// nearest training points. Ties go to the class that comes first in the class values.
    /// </summary>
    /// <remarks>Evaluated lazily.</remarks>
    /// <exception cref="InvalidOperationException">The model has no class values (regression model).</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = GetInputData(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] y = new int[classValues.Length]; // neighbour count per class index
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        Array.Clear(y, 0, y.Length);
        for (int i = 0; i < numNeighbours; i++) {
          // the tree stores class indices (0..nClasses-1) in the target column
          int classValue = (int)Math.Round(neighbours[i, columns]);
          y[classValue]++;
        }

        // find class for with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }


    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Delegates to the regression or classification compatibility check, depending on the
    /// concrete type of <paramref name="problemData"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this nearest neighbour model. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    #region events
    /// <summary>Raised when a different kd-tree is assigned via <see cref="KDTree"/>.</summary>
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion


    // BackwardsCompatibility3.3
    #region Backwards compatible code, remove with 3.4

    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
    /// <summary>
    /// True for models deserialized without this value (storable default <c>true</c>); such models
    /// predict on unscaled inputs. Newly constructed models have <c>false</c>.
    /// </summary>
    [Storable(DefaultValue = true)]
    public bool IsCompatibilityLoaded {
      get { return isCompatibilityLoaded; }
      set { isCompatibilityLoaded = value; }
    }
    #endregion
    #region persistence
    // The alglib kdtree is persisted field by field through the following storable properties,
    // each of which reads from / writes to the corresponding field of kdTree.
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.