Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2942_KNNRegressionClassification/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 16488

Last change on this file since 16488 was 16488, checked in by gkronber, 5 years ago

#2942: added SelfMatch parameters in the AfterDeserialization hook (for loading files stored with the old version)

File size: 16.6 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification
  /// </summary>
  /// <remarks>
  /// The training rows (input columns followed by the target column) are stored in an ALGLIB kd-tree.
  /// Regression estimates are the inverse-distance weighted average of the targets of the k nearest
  /// neighbours; classification estimates are the class with the most votes among the k nearest neighbours.
  /// Unless the model was loaded in compatibility mode, inputs are shifted by <c>offsets</c> and multiplied
  /// by <c>weights</c> both when the tree is built and when it is queried.
  /// </remarks>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // The ALGLIB query calls mutate the kdTree object, so queries must not run concurrently.
    private readonly object kdTreeLockObject = new object();

    private alglib.nearestneighbor.kdtree kdTree;
    /// <summary>
    /// The underlying ALGLIB kd-tree. Assigning a different tree raises <see cref="Changed"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Thrown when a different, null tree is assigned.</exception>
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          if (value == null) throw new ArgumentNullException("value");
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }

    /// <summary>The names of the input variables the model was trained on.</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    [Storable]
    private double[] classValues;
    [Storable]
    private int k;
    [Storable(DefaultValue = false)]
    private bool selfMatch;
    [Storable(DefaultValue = null)]
    private double[] weights; // not set for old versions loaded from disk
    [Storable(DefaultValue = null)]
    private double[] offsets; // not set for old versions loaded from disk

    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // the tree's fields are filled afterwards through the KDTree* persistence properties
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }

    // Deep-copy constructor used by Clone: copies every kd-tree field and all model arrays.
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      kdTree = new alglib.nearestneighbor.kdtree();
      kdTree.approxf = original.kdTree.approxf;
      kdTree.boxmax = (double[])original.kdTree.boxmax.Clone();
      kdTree.boxmin = (double[])original.kdTree.boxmin.Clone();
      kdTree.buf = (double[])original.kdTree.buf.Clone();
      kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone();
      kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone();
      kdTree.curdist = original.kdTree.curdist;
      kdTree.debugcounter = original.kdTree.debugcounter;
      kdTree.idx = (int[])original.kdTree.idx.Clone();
      kdTree.kcur = original.kdTree.kcur;
      kdTree.kneeded = original.kdTree.kneeded;
      kdTree.n = original.kdTree.n;
      kdTree.nodes = (int[])original.kdTree.nodes.Clone();
      kdTree.normtype = original.kdTree.normtype;
      kdTree.nx = original.kdTree.nx;
      kdTree.ny = original.kdTree.ny;
      kdTree.r = (double[])original.kdTree.r.Clone();
      kdTree.rneeded = original.kdTree.rneeded;
      kdTree.selfmatch = original.kdTree.selfmatch;
      kdTree.splits = (double[])original.kdTree.splits.Clone();
      kdTree.tags = (int[])original.kdTree.tags.Clone();
      kdTree.x = (double[])original.kdTree.x.Clone();
      kdTree.xy = (double[,])original.kdTree.xy.Clone();
      selfMatch = original.selfMatch;
      k = original.k;
      isCompatibilityLoaded = original.IsCompatibilityLoaded;
      if (!IsCompatibilityLoaded) {
        // scaling parameters only exist for models that were not loaded in compatibility mode
        if (original.weights != null) weights = (double[])original.weights.Clone();
        if (original.offsets != null) offsets = (double[])original.offsets.Clone();
      }
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given training rows.
    /// </summary>
    /// <param name="dataset">The dataset that contains the training data.</param>
    /// <param name="rows">The indices of the training rows.</param>
    /// <param name="k">The number of neighbours used for each estimate.</param>
    /// <param name="selfMatch">Passed to ALGLIB's kd-tree query; when true, training points at distance 0 from the query may be returned as neighbours.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The names of the input variables.</param>
    /// <param name="weights">Optional scaling factor per input variable. When null, each input is scaled by 1 / population standard deviation (or 1 if that is almost 0).</param>
    /// <param name="classValues">Optional class values. When given, target values are mapped to their index in this array so the model can be used for classification.</param>
    /// <exception cref="ArgumentException">Thrown when the number of weights differs from the number of input variables, or a target value is not contained in <paramref name="classValues"/>.</exception>
    /// <exception cref="NotSupportedException">Thrown when the (scaled) training data contains NaN or infinity.</exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, bool selfMatch, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
      : base(targetVariable) {
      Name = ItemName;
      Description = ItemDescription;
      this.selfMatch = selfMatch;
      this.k = k;
      this.allowedInputVariables = allowedInputVariables.ToArray();

      // A freshly constructed model is never in compatibility mode (isCompatibilityLoaded is
      // initialized to false), so the training data is always centered and scaled here.
      this.offsets = this.allowedInputVariables
        .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
        .Concat(new double[] { 0 }) // no offset for target variable
        .ToArray();
      if (weights == null) {
        // automatic determination of weights (all features should have variance = 1)
        this.weights = this.allowedInputVariables
          .Select(name => {
            var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
            return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
          })
          .Concat(new double[] { 1.0 }) // no scaling for target variable
          .ToArray();
      } else {
        // user specified weights (+ 1 for target)
        this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
        if (this.weights.Length - 1 != this.allowedInputVariables.Length)
          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
      }
      double[,] inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);

      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException(
          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1; // the last column holds the target

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        // map original class values to values [0..nClasses-1]
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < classValues.Length; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          // the target column is neither shifted nor scaled, so it still holds the original value
          double targetValue = inputMatrix[row, nFeatures];
          double classIndex;
          if (!classIndices.TryGetValue(targetValue, out classIndex))
            throw new ArgumentException("The target value " + targetValue + " is not contained in the provided class values.", "classValues");
          inputMatrix[row, nFeatures] = classIndex;
        }
      }
      // nx = nFeatures input columns, ny = 1 target column, normtype = 2 (Euclidean norm per the ALGLIB docs)
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, nFeatures, 1, 2, kdTree);
    }

    // Returns the values of the given variables for the given rows, transformed per column as (x + offset) * factor.
    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
      var transforms =
        variables.Select(
          (_, colIdx) =>
            new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
      return dataset.ToArray(variables, transforms, rows);
    }

    /// <inheritdoc />
    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    // Builds the query matrix (input variables only) for the given rows. Models loaded in compatibility
    // mode have no offsets/weights and use the raw values; all other models apply the training-time scaling.
    private double[,] CreateInputData(IDataset dataset, IEnumerable<int> rows) {
      if (IsCompatibilityLoaded)
        return dataset.ToArray(allowedInputVariables, rows);
      return CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    }

    // Runs a k-nearest-neighbour query for x and copies the neighbour distances and rows (inputs + target)
    // into the given buffers. Returns the number of neighbours found.
    private int QueryNeighbours(double[] x, ref double[] dists, ref double[,] neighbours) {
      lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure
        int numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch);
        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
        return numNeighbours;
      }
    }

    // Neighbour weights are 1/d, which is undefined for exact matches (d = 0). Zero distances are
    // therefore replaced by 1% of the smallest non-zero neighbour distance, so an exact match weighs
    // 100x more than the closest non-matching neighbour. If all neighbours are exact matches, they
    // all receive the same positive distance and therefore equal weight.
    private static void ReplaceZeroDistances(double[] dists, int numNeighbours) {
      double minNonZeroDist = double.PositiveInfinity;
      for (int i = 0; i < numNeighbours; i++) {
        if (dists[i] != 0 && dists[i] < minNonZeroDist) {
          minNonZeroDist = dists[i];
        }
      }
      if (double.IsPositiveInfinity(minNonZeroDist)) {
        minNonZeroDist = 1.0; // only exact matches; any positive value yields equal weights
      }
      double replacement = minNonZeroDist / 100.0;
      for (int i = 0; i < numNeighbours; i++) {
        if (dists[i] == 0) {
          dists[i] = replacement;
        }
      }
    }

    /// <summary>
    /// Estimates the target for each given row as the inverse-distance weighted average of the
    /// targets of its k nearest training neighbours. Evaluation is deferred until enumeration.
    /// </summary>
    /// <param name="dataset">The dataset containing the input variables.</param>
    /// <param name="rows">The rows to estimate.</param>
    /// <returns>One estimate per row. For classification models, the average is taken over class indices.</returns>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = CreateInputData(dataset, rows);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours = QueryNeighbours(x, ref dists, ref neighbours);
        // zero distances are only expected when self-matches are allowed
        if (selfMatch) {
          ReplaceZeroDistances(dists, numNeighbours);
        }
        double distanceWeightedValue = 0.0;
        double distsSum = 0.0;
        for (int i = 0; i < numNeighbours; i++) {
          // column 'columns' of a neighbour row holds its target value
          distanceWeightedValue += neighbours[i, columns] / dists[i];
          distsSum += 1.0 / dists[i];
        }
        yield return distanceWeightedValue / distsSum;
      }
    }

    /// <summary>
    /// Estimates the class of each given row by a majority vote among its k nearest training
    /// neighbours. Ties are resolved in favour of the class with the lower index. Evaluation is
    /// deferred until enumeration.
    /// </summary>
    /// <param name="dataset">The dataset containing the input variables.</param>
    /// <param name="rows">The rows to classify.</param>
    /// <returns>One of the model's class values per row.</returns>
    /// <exception cref="InvalidOperationException">Thrown on enumeration when the model has no class values.</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = CreateInputData(dataset, rows);
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] votes = new int[classValues.Length];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours = QueryNeighbours(x, ref dists, ref neighbours);

        // the target column of the tree holds class indices (mapped in the constructor)
        Array.Clear(votes, 0, votes.Length);
        for (int i = 0; i < numNeighbours; i++) {
          int classIndex = (int)Math.Round(neighbours[i, columns]);
          votes[classIndex]++;
        }

        // find the class with the most votes (first index wins on ties)
        int maxVotesClassIndex = 0;
        double maxVotes = votes[0];
        for (int i = 1; i < votes.Length; i++) {
          if (maxVotes < votes[i]) {
            maxVotes = votes[i];
            maxVotesClassIndex = i;
          }
        }
        yield return classValues[maxVotesClassIndex];
      }
    }


    /// <summary>Checks whether the regression problem data is compatible with this model.</summary>
    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches the compatibility check to the regression or classification overload.
    /// </summary>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when the problem data is neither regression nor classification data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not a regression nor a classification problem data. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    /// <inheritdoc />
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    #region events
    /// <summary>Raised when a different kd-tree is assigned to <see cref="KDTree"/>.</summary>
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion


    // BackwardsCompatibility3.3
    #region Backwards compatible code, remove with 3.4

    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
    /// <summary>
    /// True for models loaded from files that do not store this flag (its storable default is true);
    /// such models work on unscaled input data.
    /// </summary>
    [Storable(DefaultValue = true)]
    public bool IsCompatibilityLoaded {
      get { return isCompatibilityLoaded; }
      set { isCompatibilityLoaded = value; }
    }
    #endregion
    #region persistence
    // The kd-tree is persisted field by field through the following properties.
    // NOTE(review): the property names presumably act as storage keys — renaming them would likely
    // break loading of existing files.
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.