Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModelAlglib_3_7.cs @ 18065

Last change on this file since 18065 was 17931, checked in by gkronber, 3 years ago

#3117: update alglib to version 3.17

File size: 17.2 KB
RevLine 
[16491]1#region License Information
[6583]2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[6583]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[17931]22extern alias alglib_3_7;
[6583]23using System;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
[16565]28using HEAL.Attic;
[6583]29using HeuristicLab.Problems.DataAnalysis;
30
31namespace HeuristicLab.Algorithms.DataAnalysis {
32  /// <summary>
33  /// Represents a nearest neighbour model for regression and classification
34  /// </summary>
[16565]35  [StorableType("A76C0823-3077-4ACE-8A40-E9B717C7DB60")]
[8465]36  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
[17931]37  [Obsolete("This version uses alglib version 3.7. Use NearestNeighbourModel instead.")]
38  public sealed class NearestNeighbourModelAlglib_3_7 : ClassificationModel, INearestNeighbourModel {
[6583]39
[14322]40    private readonly object kdTreeLockObject = new object();
[16491]41
[17931]42    private alglib_3_7.alglib.nearestneighbor.kdtree kdTree;
43    public alglib_3_7.alglib.nearestneighbor.kdtree KDTree {
[6583]44      get { return kdTree; }
45      set {
46        if (value != kdTree) {
47          if (value == null) throw new ArgumentNullException();
48          kdTree = value;
49          OnChanged(EventArgs.Empty);
50        }
51      }
52    }
53
[13941]54    public override IEnumerable<string> VariablesUsedForPrediction {
[13921]55      get { return allowedInputVariables; }
56    }
57
[6583]58    [Storable]
59    private string[] allowedInputVariables;
60    [Storable]
61    private double[] classValues;
62    [Storable]
63    private int k;
[16491]64    [Storable(DefaultValue = false)]
65    private bool selfMatch;
[14235]66    [Storable(DefaultValue = null)]
67    private double[] weights; // not set for old versions loaded from disk
68    [Storable(DefaultValue = null)]
69    private double[] offsets; // not set for old versions loaded from disk
[8465]70
[6583]71    [StorableConstructor]
[17931]72    private NearestNeighbourModelAlglib_3_7(StorableConstructorFlag _) : base(_) {
73      kdTree = new alglib_3_7.alglib.nearestneighbor.kdtree();
[6583]74    }
[17931]75    private NearestNeighbourModelAlglib_3_7(NearestNeighbourModelAlglib_3_7 original, Cloner cloner)
[6583]76      : base(original, cloner) {
[17931]77      kdTree = new alglib_3_7.alglib.nearestneighbor.kdtree();
[6583]78      kdTree.approxf = original.kdTree.approxf;
79      kdTree.boxmax = (double[])original.kdTree.boxmax.Clone();
80      kdTree.boxmin = (double[])original.kdTree.boxmin.Clone();
81      kdTree.buf = (double[])original.kdTree.buf.Clone();
82      kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone();
83      kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone();
84      kdTree.curdist = original.kdTree.curdist;
85      kdTree.debugcounter = original.kdTree.debugcounter;
86      kdTree.idx = (int[])original.kdTree.idx.Clone();
87      kdTree.kcur = original.kdTree.kcur;
88      kdTree.kneeded = original.kdTree.kneeded;
89      kdTree.n = original.kdTree.n;
90      kdTree.nodes = (int[])original.kdTree.nodes.Clone();
91      kdTree.normtype = original.kdTree.normtype;
92      kdTree.nx = original.kdTree.nx;
93      kdTree.ny = original.kdTree.ny;
94      kdTree.r = (double[])original.kdTree.r.Clone();
95      kdTree.rneeded = original.kdTree.rneeded;
96      kdTree.selfmatch = original.kdTree.selfmatch;
97      kdTree.splits = (double[])original.kdTree.splits.Clone();
98      kdTree.tags = (int[])original.kdTree.tags.Clone();
99      kdTree.x = (double[])original.kdTree.x.Clone();
100      kdTree.xy = (double[,])original.kdTree.xy.Clone();
[16491]101      selfMatch = original.selfMatch;
[6583]102      k = original.k;
[14235]103      isCompatibilityLoaded = original.IsCompatibilityLoaded;
104      if (!IsCompatibilityLoaded) {
105        weights = new double[original.weights.Length];
106        Array.Copy(original.weights, weights, weights.Length);
107        offsets = new double[original.offsets.Length];
108        Array.Copy(original.offsets, this.offsets, this.offsets.Length);
109      }
[6583]110      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
111      if (original.classValues != null)
112        this.classValues = (double[])original.classValues.Clone();
113    }
[17931]114    public NearestNeighbourModelAlglib_3_7(IDataset dataset, IEnumerable<int> rows, int k, bool selfMatch, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
[13941]115      : base(targetVariable) {
[8467]116      Name = ItemName;
117      Description = ItemDescription;
[16491]118      this.selfMatch = selfMatch;
[6583]119      this.k = k;
120      this.allowedInputVariables = allowedInputVariables.ToArray();
[14235]121      double[,] inputMatrix;
122      if (IsCompatibilityLoaded) {
123        // no scaling
[14843]124        inputMatrix = dataset.ToArray(
[14235]125          this.allowedInputVariables.Concat(new string[] { targetVariable }),
126          rows);
127      } else {
128        this.offsets = this.allowedInputVariables
129          .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
130          .Concat(new double[] { 0 }) // no offset for target variable
131          .ToArray();
132        if (weights == null) {
133          // automatic determination of weights (all features should have variance = 1)
134          this.weights = this.allowedInputVariables
[16086]135            .Select(name => {
136              var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
[16491]137              return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
[16086]138            })
[14235]139            .Concat(new double[] { 1.0 }) // no scaling for target variable
140            .ToArray();
141        } else {
142          // user specified weights (+ 1 for target)
143          this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
144          if (this.weights.Length - 1 != this.allowedInputVariables.Length)
145            throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
146        }
147        inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
148      }
[8465]149
[15786]150      if (inputMatrix.ContainsNanOrInfinity())
[8465]151        throw new NotSupportedException(
[14826]152          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");
[8465]153
[17931]154      this.kdTree = new alglib_3_7.alglib.nearestneighbor.kdtree();
[8465]155
156      var nRows = inputMatrix.GetLength(0);
157      var nFeatures = inputMatrix.GetLength(1) - 1;
158
159      if (classValues != null) {
[6583]160        this.classValues = (double[])classValues.Clone();
[8465]161        int nClasses = classValues.Length;
162        // map original class values to values [0..nClasses-1]
163        var classIndices = new Dictionary<double, double>();
164        for (int i = 0; i < nClasses; i++)
165          classIndices[classValues[i]] = i;
166
167        for (int row = 0; row < nRows; row++) {
168          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
169        }
170      }
[17931]171      alglib_3_7.alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree);
[6583]172    }
173
[14235]174    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
[14843]175      var transforms =
176        variables.Select(
177          (_, colIdx) =>
178            new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
179      return dataset.ToArray(variables, transforms, rows);
[14235]180    }
181
[6583]182    public override IDeepCloneable Clone(Cloner cloner) {
[17931]183      return new NearestNeighbourModelAlglib_3_7(this, cloner);
[6583]184    }
185
[12509]186    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
[14235]187      double[,] inputData;
188      if (IsCompatibilityLoaded) {
[14843]189        inputData = dataset.ToArray(allowedInputVariables, rows);
[14235]190      } else {
191        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
192      }
[6583]193
194      int n = inputData.GetLength(0);
195      int columns = inputData.GetLength(1);
196      double[] x = new double[columns];
197      double[] dists = new double[k];
198      double[,] neighbours = new double[k, columns + 1];
199
200      for (int row = 0; row < n; row++) {
201        for (int column = 0; column < columns; column++) {
202          x[column] = inputData[row, column];
203        }
[14236]204        int numNeighbours;
[14314]205        lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure
[17931]206          numNeighbours = alglib_3_7.alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch);
207          alglib_3_7.alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
208          alglib_3_7.alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
[14236]209        }
[16491]210        if (selfMatch) {
211          // weights for neighbours are 1/d.
212          // override distances (=0) of exact matches using 1% of the distance of the next closest non-self-match neighbour -> selfmatches weight 100x more than the next closest neighbor.
213          // if all k neighbours are selfmatches then they all have weight 0.01.
214          double minDist = dists[0] + 1;
215          for (int i = 0; i < numNeighbours; i++) {
216            if ((minDist > dists[i]) && (dists[i] != 0)) {
217              minDist = dists[i];
218            }
219          }
220          minDist /= 100.0;
221          for (int i = 0; i < numNeighbours; i++) {
222            if (dists[i] == 0) {
223              dists[i] = minDist;
224            }
225          }
226        }
[6583]227        double distanceWeightedValue = 0.0;
228        double distsSum = 0.0;
[14236]229        for (int i = 0; i < numNeighbours; i++) {
[6583]230          distanceWeightedValue += neighbours[i, columns] / dists[i];
231          distsSum += 1.0 / dists[i];
232        }
233        yield return distanceWeightedValue / distsSum;
234      }
235    }
236
[13941]237    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
[8465]238      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
[14235]239      double[,] inputData;
240      if (IsCompatibilityLoaded) {
[14843]241        inputData = dataset.ToArray(allowedInputVariables, rows);
[14235]242      } else {
243        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
244      }
[6583]245      int n = inputData.GetLength(0);
246      int columns = inputData.GetLength(1);
247      double[] x = new double[columns];
248      int[] y = new int[classValues.Length];
249      double[] dists = new double[k];
250      double[,] neighbours = new double[k, columns + 1];
251
252      for (int row = 0; row < n; row++) {
253        for (int column = 0; column < columns; column++) {
254          x[column] = inputData[row, column];
255        }
[14236]256        int numNeighbours;
[14314]257        lock (kdTreeLockObject) {
[14236]258          // gkronber: the following calls change the kdTree data structure
[17931]259          numNeighbours = alglib_3_7.alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch);
260          alglib_3_7.alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
261          alglib_3_7.alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
[14236]262        }
[6583]263        Array.Clear(y, 0, y.Length);
[14236]264        for (int i = 0; i < numNeighbours; i++) {
[6583]265          int classValue = (int)Math.Round(neighbours[i, columns]);
266          y[classValue]++;
267        }
268
269        // find class for with the largest probability value
270        int maxProbClassIndex = 0;
271        double maxProb = y[0];
272        for (int i = 1; i < y.Length; i++) {
273          if (maxProb < y[i]) {
274            maxProb = y[i];
275            maxProbClassIndex = i;
276          }
277        }
278        yield return classValues[maxProbClassIndex];
279      }
280    }
281
[13941]282
[16243]283    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
284      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
285    }
286
287    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
288      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");
289
290      var regressionProblemData = problemData as IRegressionProblemData;
291      if (regressionProblemData != null)
292        return IsProblemDataCompatible(regressionProblemData, out errorMessage);
293
294      var classificationProblemData = problemData as IClassificationProblemData;
295      if (classificationProblemData != null)
296        return IsProblemDataCompatible(classificationProblemData, out errorMessage);
297
[16763]298      throw new ArgumentException("The problem data is not compatible with this nearest neighbour model. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
[16243]299    }
300
[6603]301    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
[13941]302      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
[6603]303    }
[13941]304    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
305      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
[6604]306    }
[6603]307
[6583]308    #region events
309    public event EventHandler Changed;
310    private void OnChanged(EventArgs e) {
311      var handlers = Changed;
312      if (handlers != null)
313        handlers(this, e);
314    }
315    #endregion
316
[14235]317
318    // BackwardsCompatibility3.3
319    #region Backwards compatible code, remove with 3.4
320
321    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
322    [Storable(DefaultValue = true)]
323    public bool IsCompatibilityLoaded {
324      get { return isCompatibilityLoaded; }
325      set { isCompatibilityLoaded = value; }
326    }
327    #endregion
[6583]328    #region persistence
[6584]329    [Storable]
330    public double KDTreeApproxF {
331      get { return kdTree.approxf; }
332      set { kdTree.approxf = value; }
333    }
334    [Storable]
335    public double[] KDTreeBoxMax {
336      get { return kdTree.boxmax; }
337      set { kdTree.boxmax = value; }
338    }
339    [Storable]
340    public double[] KDTreeBoxMin {
341      get { return kdTree.boxmin; }
342      set { kdTree.boxmin = value; }
343    }
344    [Storable]
345    public double[] KDTreeBuf {
346      get { return kdTree.buf; }
347      set { kdTree.buf = value; }
348    }
349    [Storable]
350    public double[] KDTreeCurBoxMax {
351      get { return kdTree.curboxmax; }
352      set { kdTree.curboxmax = value; }
353    }
354    [Storable]
355    public double[] KDTreeCurBoxMin {
356      get { return kdTree.curboxmin; }
357      set { kdTree.curboxmin = value; }
358    }
359    [Storable]
360    public double KDTreeCurDist {
361      get { return kdTree.curdist; }
362      set { kdTree.curdist = value; }
363    }
364    [Storable]
365    public int KDTreeDebugCounter {
366      get { return kdTree.debugcounter; }
367      set { kdTree.debugcounter = value; }
368    }
369    [Storable]
370    public int[] KDTreeIdx {
371      get { return kdTree.idx; }
372      set { kdTree.idx = value; }
373    }
374    [Storable]
375    public int KDTreeKCur {
376      get { return kdTree.kcur; }
377      set { kdTree.kcur = value; }
378    }
379    [Storable]
380    public int KDTreeKNeeded {
381      get { return kdTree.kneeded; }
382      set { kdTree.kneeded = value; }
383    }
384    [Storable]
385    public int KDTreeN {
386      get { return kdTree.n; }
387      set { kdTree.n = value; }
388    }
389    [Storable]
390    public int[] KDTreeNodes {
391      get { return kdTree.nodes; }
392      set { kdTree.nodes = value; }
393    }
394    [Storable]
395    public int KDTreeNormType {
396      get { return kdTree.normtype; }
397      set { kdTree.normtype = value; }
398    }
399    [Storable]
400    public int KDTreeNX {
401      get { return kdTree.nx; }
402      set { kdTree.nx = value; }
403    }
404    [Storable]
405    public int KDTreeNY {
406      get { return kdTree.ny; }
407      set { kdTree.ny = value; }
408    }
409    [Storable]
410    public double[] KDTreeR {
411      get { return kdTree.r; }
412      set { kdTree.r = value; }
413    }
414    [Storable]
415    public double KDTreeRNeeded {
416      get { return kdTree.rneeded; }
417      set { kdTree.rneeded = value; }
418    }
419    [Storable]
420    public bool KDTreeSelfMatch {
421      get { return kdTree.selfmatch; }
422      set { kdTree.selfmatch = value; }
423    }
424    [Storable]
425    public double[] KDTreeSplits {
426      get { return kdTree.splits; }
427      set { kdTree.splits = value; }
428    }
429    [Storable]
430    public int[] KDTreeTags {
431      get { return kdTree.tags; }
432      set { kdTree.tags = value; }
433    }
434    [Storable]
435    public double[] KDTreeX {
436      get { return kdTree.x; }
437      set { kdTree.x = value; }
438    }
439    [Storable]
440    public double[,] KDTreeXY {
441      get { return kdTree.xy; }
442      set { kdTree.xy = value; }
443    }
[6583]444    #endregion
445  }
446}
Note: See TracBrowser for help on using the repository browser.