Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2965_CancelablePersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 16683

Last change on this file since 16683 was 16243, checked in by mkommend, 6 years ago

#2955: Added IsProblemDataCompatible and IsDatasetCompatible to all DataAnalysisModels.

File size: 16.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification.
  /// Predictions are computed from the k nearest training points found via an
  /// alglib kd-tree: inverse-distance-weighted average for regression, majority
  /// vote for classification.
  /// </summary>
  [StorableClass]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // Guards all alglib kd-tree queries: the query calls mutate internal kd-tree
    // buffers, so concurrent estimation over the same model must be serialized.
    private readonly object kdTreeLockObject = new object();
    private alglib.nearestneighbor.kdtree kdTree;
    public alglib.nearestneighbor.kdtree KDTree {
      get { return kdTree; }
      set {
        if (value != kdTree) {
          // fix: name the offending parameter instead of throwing a bare ArgumentNullException
          if (value == null) throw new ArgumentNullException("value");
          kdTree = value;
          OnChanged(EventArgs.Empty);
        }
      }
    }


    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    // Distinct original class values (classification only); null for regression models.
    [Storable]
    private double[] classValues;
    // Number of neighbours consulted per prediction.
    [Storable]
    private int k;
    [Storable(DefaultValue = null)]
    private double[] weights; // not set for old versions loaded from disk
    [Storable(DefaultValue = null)]
    private double[] offsets; // not set for old versions loaded from disk

    [StorableConstructor]
    private NearestNeighbourModel(bool deserializing)
      : base(deserializing) {
      // Pre-create an empty kd-tree so the persistence properties below
      // (KDTreeApproxF, KDTreeBoxMax, ...) have a target to write into.
      if (deserializing)
        kdTree = new alglib.nearestneighbor.kdtree();
    }
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      // alglib's kdtree has no public clone/copy support, so every internal
      // field is deep-copied explicitly here. Keep this list in sync with the
      // persistence region at the bottom of the class.
      kdTree = new alglib.nearestneighbor.kdtree();
      kdTree.approxf = original.kdTree.approxf;
      kdTree.boxmax = (double[])original.kdTree.boxmax.Clone();
      kdTree.boxmin = (double[])original.kdTree.boxmin.Clone();
      kdTree.buf = (double[])original.kdTree.buf.Clone();
      kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone();
      kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone();
      kdTree.curdist = original.kdTree.curdist;
      kdTree.debugcounter = original.kdTree.debugcounter;
      kdTree.idx = (int[])original.kdTree.idx.Clone();
      kdTree.kcur = original.kdTree.kcur;
      kdTree.kneeded = original.kdTree.kneeded;
      kdTree.n = original.kdTree.n;
      kdTree.nodes = (int[])original.kdTree.nodes.Clone();
      kdTree.normtype = original.kdTree.normtype;
      kdTree.nx = original.kdTree.nx;
      kdTree.ny = original.kdTree.ny;
      kdTree.r = (double[])original.kdTree.r.Clone();
      kdTree.rneeded = original.kdTree.rneeded;
      kdTree.selfmatch = original.kdTree.selfmatch;
      kdTree.splits = (double[])original.kdTree.splits.Clone();
      kdTree.tags = (int[])original.kdTree.tags.Clone();
      kdTree.x = (double[])original.kdTree.x.Clone();
      kdTree.xy = (double[,])original.kdTree.xy.Clone();

      k = original.k;
      isCompatibilityLoaded = original.IsCompatibilityLoaded;
      // weights/offsets only exist for models created after scaling was introduced
      if (!IsCompatibilityLoaded) {
        weights = new double[original.weights.Length];
        Array.Copy(original.weights, weights, weights.Length);
        offsets = new double[original.offsets.Length];
        Array.Copy(original.offsets, this.offsets, this.offsets.Length);
      }
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
    }
    /// <summary>
    /// Builds a nearest neighbour model from the given rows of the dataset.
    /// </summary>
    /// <param name="dataset">Training data source.</param>
    /// <param name="rows">Row indices used for training.</param>
    /// <param name="k">Number of neighbours per prediction.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Input variable names (in order).</param>
    /// <param name="weights">Optional per-input scaling factors; if null, weights are
    /// chosen so every feature has unit variance (constant features get weight 1).</param>
    /// <param name="classValues">Distinct class values for classification; null for regression.</param>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
      : base(targetVariable) {
      Name = ItemName;
      Description = ItemDescription;
      this.k = k;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      double[,] inputMatrix;
      if (IsCompatibilityLoaded) {
        // no scaling
        inputMatrix = dataset.ToArray(
          this.allowedInputVariables.Concat(new string[] { targetVariable }),
          rows);
      } else {
        // center each input on its mean; the target column gets no offset
        this.offsets = this.allowedInputVariables
          .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
          .Concat(new double[] { 0 }) // no offset for target variable
          .ToArray();
        if (weights == null) {
          // automatic determination of weights (all features should have variance = 1)
          this.weights = this.allowedInputVariables
            .Select(name => {
              var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
              return  pop.IsAlmost(0) ? 1.0 : 1.0/pop;
            })
            .Concat(new double[] { 1.0 }) // no scaling for target variable
            .ToArray();
        } else {
          // user specified weights (+ 1 for target)
          this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
          if (this.weights.Length - 1 != this.allowedInputVariables.Length)
            throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
        }
        inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
      }

      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException(
          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");

      this.kdTree = new alglib.nearestneighbor.kdtree();

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1;

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        int nClasses = classValues.Length;
        // map original class values to values [0..nClasses-1]
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < nClasses; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
        }
      }
      // normtype 2 = Euclidean distance; last column (target) is stored as the
      // single y-value attached to each kd-tree point
      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree);
    }

    // Applies the precomputed offset/factor per column as a linear transformation
    // while extracting the matrix. NOTE(review): assumes offsets/factors are indexed
    // in the same order as 'variables' — holds for both call sites in this class.
    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
      var transforms =
        variables.Select(
          (_, colIdx) =>
            new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
      return dataset.ToArray(variables, transforms, rows);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Returns inverse-distance-weighted regression estimates for the given rows.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData;
      if (IsCompatibilityLoaded) {
        inputData = dataset.ToArray(allowedInputVariables, rows);
      } else {
        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
      }

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours;
        lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure
          numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
          alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
          alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
        }

        double distanceWeightedValue = 0.0;
        double distsSum = 0.0;
        for (int i = 0; i < numNeighbours; i++) {
          // fix: guard against a zero distance, which would make both sums infinite
          // and the estimate NaN. alglib with selfmatch=false should not report
          // zero-distance matches, but if one occurs (exact coincidence with a
          // training point) return that neighbour's target value directly.
          if (dists[i] == 0.0) {
            distanceWeightedValue = neighbours[i, columns];
            distsSum = 1.0;
            break;
          }
          distanceWeightedValue += neighbours[i, columns] / dists[i];
          distsSum += 1.0 / dists[i];
        }
        yield return distanceWeightedValue / distsSum;
      }
    }

    /// <summary>
    /// Returns majority-vote class estimates for the given rows.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the model was built without class values (regression model).</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData;
      if (IsCompatibilityLoaded) {
        inputData = dataset.ToArray(allowedInputVariables, rows);
      } else {
        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
      }
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      int[] y = new int[classValues.Length]; // vote count per class index
      double[] dists = new double[k];
      double[,] neighbours = new double[k, columns + 1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        int numNeighbours;
        lock (kdTreeLockObject) {
          // gkronber: the following calls change the kdTree data structure
          numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
          alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
          alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
        }
        Array.Clear(y, 0, y.Length);
        for (int i = 0; i < numNeighbours; i++) {
          // neighbours[i, columns] holds the class INDEX stored at build time
          int classValue = (int)Math.Round(neighbours[i, columns]);
          y[classValue]++;
        }

        // find class for with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        // on ties the class with the lowest index wins
        yield return classValues[maxProbClassIndex];
      }
    }


    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches the compatibility check to the regression or classification
    /// overload, depending on the concrete type of <paramref name="problemData"/>.
    /// </summary>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not a regression nor a classification problem data. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    #region events
    public event EventHandler Changed;
    private void OnChanged(EventArgs e) {
      var handlers = Changed;
      if (handlers != null)
        handlers(this, e);
    }
    #endregion


    // BackwardsCompatibility3.3
    #region Backwards compatible code, remove with 3.4
    // When true, the model predates input scaling (weights/offsets are null) and
    // prediction skips the scaling step to reproduce the old behaviour.
    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
    [Storable(DefaultValue = true)]
    public bool IsCompatibilityLoaded {
      get { return isCompatibilityLoaded; }
      set { isCompatibilityLoaded = value; }
    }
    #endregion
    #region persistence
    // The alglib kd-tree is not itself [Storable]; each internal field is exposed
    // through a storable property instead. Names and types are part of the
    // persisted format — do not rename.
    [Storable]
    public double KDTreeApproxF {
      get { return kdTree.approxf; }
      set { kdTree.approxf = value; }
    }
    [Storable]
    public double[] KDTreeBoxMax {
      get { return kdTree.boxmax; }
      set { kdTree.boxmax = value; }
    }
    [Storable]
    public double[] KDTreeBoxMin {
      get { return kdTree.boxmin; }
      set { kdTree.boxmin = value; }
    }
    [Storable]
    public double[] KDTreeBuf {
      get { return kdTree.buf; }
      set { kdTree.buf = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMax {
      get { return kdTree.curboxmax; }
      set { kdTree.curboxmax = value; }
    }
    [Storable]
    public double[] KDTreeCurBoxMin {
      get { return kdTree.curboxmin; }
      set { kdTree.curboxmin = value; }
    }
    [Storable]
    public double KDTreeCurDist {
      get { return kdTree.curdist; }
      set { kdTree.curdist = value; }
    }
    [Storable]
    public int KDTreeDebugCounter {
      get { return kdTree.debugcounter; }
      set { kdTree.debugcounter = value; }
    }
    [Storable]
    public int[] KDTreeIdx {
      get { return kdTree.idx; }
      set { kdTree.idx = value; }
    }
    [Storable]
    public int KDTreeKCur {
      get { return kdTree.kcur; }
      set { kdTree.kcur = value; }
    }
    [Storable]
    public int KDTreeKNeeded {
      get { return kdTree.kneeded; }
      set { kdTree.kneeded = value; }
    }
    [Storable]
    public int KDTreeN {
      get { return kdTree.n; }
      set { kdTree.n = value; }
    }
    [Storable]
    public int[] KDTreeNodes {
      get { return kdTree.nodes; }
      set { kdTree.nodes = value; }
    }
    [Storable]
    public int KDTreeNormType {
      get { return kdTree.normtype; }
      set { kdTree.normtype = value; }
    }
    [Storable]
    public int KDTreeNX {
      get { return kdTree.nx; }
      set { kdTree.nx = value; }
    }
    [Storable]
    public int KDTreeNY {
      get { return kdTree.ny; }
      set { kdTree.ny = value; }
    }
    [Storable]
    public double[] KDTreeR {
      get { return kdTree.r; }
      set { kdTree.r = value; }
    }
    [Storable]
    public double KDTreeRNeeded {
      get { return kdTree.rneeded; }
      set { kdTree.rneeded = value; }
    }
    [Storable]
    public bool KDTreeSelfMatch {
      get { return kdTree.selfmatch; }
      set { kdTree.selfmatch = value; }
    }
    [Storable]
    public double[] KDTreeSplits {
      get { return kdTree.splits; }
      set { kdTree.splits = value; }
    }
    [Storable]
    public int[] KDTreeTags {
      get { return kdTree.tags; }
      set { kdTree.tags = value; }
    }
    [Storable]
    public double[] KDTreeX {
      get { return kdTree.x; }
      set { kdTree.x = value; }
    }
    [Storable]
    public double[,] KDTreeXY {
      get { return kdTree.xy; }
      set { kdTree.xy = value; }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.