Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 18079

Last change on this file since 18079 was 17934, checked in by gkronber, 4 years ago

#3117: fixed build fail

File size: 9.2 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
[16565]27using HEAL.Attic;
[6583]28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a nearest neighbour model for regression and classification
  /// </summary>
  [StorableType("04A07DF6-6EB5-4D29-B7AE-5BE204CAF6BC")]
  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {

    // The alglib model itself is not storable; it is persisted through SerializedModel.
    private alglib.knnmodel model;

    /// <summary>
    /// String form of <see cref="model"/> used for persistence.
    /// Yields null when no model has been built; assigning null leaves the current model untouched.
    /// </summary>
    [Storable]
    private string SerializedModel {
      get {
        if (model == null) return null; // mirror the setter, which also tolerates null
        alglib.knnserialize(model, out var ser);
        return ser;
      }
      set { if (value != null) alglib.knnunserialize(value, out model); }
    }

    /// <summary>
    /// The names of the input variables the model was built with.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;
    // Original class values; null when the model was built for regression.
    [Storable]
    private double[] classValues;
    [Storable]
    private int k;
    // Per-column scaling factors; one entry per input variable plus a trailing 1.0 for the target.
    [Storable]
    private double[] weights;
    // Per-column offsets (negated means); one entry per input variable plus a trailing 0 for the target.
    [Storable]
    private double[] offsets;

    [StorableConstructor]
    private NearestNeighbourModel(StorableConstructorFlag _) : base(_) { }

    /// <summary>
    /// Cloning constructor: copies the alglib model and all arrays so the clone shares no mutable state.
    /// </summary>
    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
      : base(original, cloner) {
      if (original.model != null)
        model = (alglib.knnmodel)original.model.make_copy();
      k = original.k;
      weights = (double[])original.weights?.Clone();
      offsets = (double[])original.offsets?.Clone();
      allowedInputVariables = (string[])original.allowedInputVariables?.Clone();
      classValues = (double[])original.classValues?.Clone();
    }

    /// <summary>
    /// Builds a nearest neighbour model from the given rows of a dataset.
    /// </summary>
    /// <param name="dataset">The dataset providing the input and target values.</param>
    /// <param name="rows">The rows of <paramref name="dataset"/> used to build the model.</param>
    /// <param name="k">The number of neighbours handed to the alglib model builder.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The names of the input variables.</param>
    /// <param name="weights">Optional scaling factor per input variable. When null, each input is scaled by
    /// the inverse of its population standard deviation (or 1.0 when that deviation is almost zero).</param>
    /// <param name="classValues">Optional class values. When given, a classification model is built and the
    /// target values are mapped to the indices of this array; otherwise a regression model is built.</param>
    /// <exception cref="ArgumentNullException">A required argument is null.</exception>
    /// <exception cref="ArgumentException">The number of weights does not match the number of input variables,
    /// or the target column contains a value that is not listed in <paramref name="classValues"/>.</exception>
    /// <exception cref="NotSupportedException">The scaled data contains NaN or infinity.</exception>
    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
      : base(targetVariable) {
      if (dataset == null) throw new ArgumentNullException(nameof(dataset));
      if (rows == null) throw new ArgumentNullException(nameof(rows));
      if (allowedInputVariables == null) throw new ArgumentNullException(nameof(allowedInputVariables));

      Name = ItemName;
      Description = ItemDescription;
      this.k = k;
      this.allowedInputVariables = allowedInputVariables.ToArray();

      // rows is consumed several times below (means, deviations, data matrix); materialise it once
      var trainingRows = rows.ToArray();

      this.offsets = this.allowedInputVariables
        .Select(name => dataset.GetDoubleValues(name, trainingRows).Average() * -1)
        .Concat(new double[] { 0 }) // no offset for target variable
        .ToArray();
      if (weights == null) {
        // automatic determination of weights (all features should have variance = 1)
        this.weights = this.allowedInputVariables
          .Select(name => {
            var pop = dataset.GetDoubleValues(name, trainingRows).StandardDeviationPop();
            return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
          })
          .Concat(new double[] { 1.0 }) // no scaling for target variable
          .ToArray();
      } else {
        // user specified weights (+ 1 for target)
        var userWeights = weights.ToArray();
        if (userWeights.Length != this.allowedInputVariables.Length)
          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
        this.weights = userWeights.Concat(new double[] { 1.0 }).ToArray();
      }

      // last column of the matrix is the (unscaled) target variable
      double[,] inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), trainingRows, this.offsets, this.weights);

      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException(
          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");

      var nRows = inputMatrix.GetLength(0);
      var nFeatures = inputMatrix.GetLength(1) - 1;

      if (classValues != null) {
        this.classValues = (double[])classValues.Clone();
        int nClasses = classValues.Length;
        // map original class values to values [0..nClasses-1]
        var classIndices = new Dictionary<double, double>();
        for (int i = 0; i < nClasses; i++)
          classIndices[classValues[i]] = i;

        for (int row = 0; row < nRows; row++) {
          var targetValue = inputMatrix[row, nFeatures];
          if (!classIndices.TryGetValue(targetValue, out var classIndex))
            throw new ArgumentException("The target variable contains the value " + targetValue + " which is not one of the given class values.", nameof(classValues));
          inputMatrix[row, nFeatures] = classIndex;
        }
      }

      alglib.knnbuildercreate(out var knnbuilder);
      if (classValues == null) {
        alglib.knnbuildersetdatasetreg(knnbuilder, inputMatrix, nRows, nFeatures, nout: 1);
      } else {
        alglib.knnbuildersetdatasetcls(knnbuilder, inputMatrix, nRows, nFeatures, classValues.Length);
      }
      alglib.knnbuilderbuildknnmodel(knnbuilder, k, 0.0, out model, out _); // eps=0 (exact k-nn search is performed)
    }

    /// <summary>
    /// Extracts the given variables and rows from the dataset and applies, per column,
    /// the linear transformation (value + offset) * factor.
    /// </summary>
    /// <param name="offsets">Offset per column, indexed by the position of the variable in <paramref name="variables"/>.</param>
    /// <param name="factors">Scaling factor per column, indexed like <paramref name="offsets"/>.</param>
    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
      // materialise once: the sequence is handed to every transformation and to ToArray below
      var variableNames = variables.ToArray();
      var transforms =
        variableNames.Select(
          (_, colIdx) =>
            new LinearTransformation(variableNames) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
      return dataset.ToArray(variableNames, transforms, rows);
    }

    // Copies one row of the matrix into the reusable query vector.
    private static void CopyRow(double[,] data, int row, double[] target) {
      for (int column = 0; column < target.Length; column++)
        target[column] = data[row, column];
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NearestNeighbourModel(this, cloner);
    }

    /// <summary>
    /// Lazily evaluates the model for the given rows and yields the first element of the model output for each row.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);

      int n = inputData.GetLength(0);
      double[] x = new double[inputData.GetLength(1)];

      alglib.knncreatebuffer(model, out var buf);
      var y = new double[1];
      for (int row = 0; row < n; row++) {
        CopyRow(inputData, row, x);
        alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
        yield return y[0];
      }
    }

    /// <summary>
    /// Lazily evaluates the model for the given rows and yields, for each row, the class value
    /// whose model output is largest (the first such class on ties).
    /// </summary>
    /// <exception cref="InvalidOperationException">The model was built without class values. Raised when enumeration starts.</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
      double[,] inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);

      int n = inputData.GetLength(0);
      double[] x = new double[inputData.GetLength(1)];

      alglib.knncreatebuffer(model, out var buf);
      var y = new double[classValues.Length];
      for (int row = 0; row < n; row++) {
        CopyRow(inputData, row, x);
        alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
        // find most probable class: compare the scores, not the index with a score
        var maxC = 0;
        for (int i = 1; i < y.Length; i++)
          if (y[maxC] < y[i]) maxC = i;
        yield return classValues[maxC];
      }
    }

    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches the compatibility check to the regression or classification variant depending on the type of the problem data.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException(nameof(problemData), "The provided problemData is null.");

      if (problemData is IRegressionProblemData regressionProblemData)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      if (problemData is IClassificationProblemData classificationProblemData)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this nearest neighbour model. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", nameof(problemData));
    }

    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.