Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3026_IntegrationIntoSymSpace/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs @ 18242

Last change on this file since 18242 was 18027, checked in by dpiringe, 3 years ago

#3026

  • merged trunk into branch
File size: 9.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HEAL.Attic;
28using HeuristicLab.Problems.DataAnalysis;
29
30namespace HeuristicLab.Algorithms.DataAnalysis {
31  /// <summary>
32  /// Represents a nearest neighbour model for regression and classification
33  /// </summary>
34  [StorableType("04A07DF6-6EB5-4D29-B7AE-5BE204CAF6BC")]
35  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
36  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {
37
38    private alglib.knnmodel model;
39    [Storable]
40    private string SerializedModel {
41      get { alglib.knnserialize(model, out var ser); return ser; }
42      set { if (value != null) alglib.knnunserialize(value, out model); }
43    }
44
45    public override IEnumerable<string> VariablesUsedForPrediction {
46      get { return allowedInputVariables; }
47    }
48
49    [Storable]
50    private string[] allowedInputVariables;
51    [Storable]
52    private double[] classValues;
53    [Storable]
54    private int k;
55    [Storable]
56    private double[] weights;
57    [Storable]
58    private double[] offsets;
59
60    [StorableConstructor]
61    private NearestNeighbourModel(StorableConstructorFlag _) : base(_) { }
62    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
63      : base(original, cloner) {
64      if (original.model != null)
65        model = (alglib.knnmodel)original.model.make_copy();
66      k = original.k;
67      weights = new double[original.weights.Length];
68      Array.Copy(original.weights, weights, weights.Length);
69      offsets = new double[original.offsets.Length];
70      Array.Copy(original.offsets, this.offsets, this.offsets.Length);
71
72      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
73      if (original.classValues != null)
74        this.classValues = (double[])original.classValues.Clone();
75    }
76    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
77      : base(targetVariable) {
78      Name = ItemName;
79      Description = ItemDescription;
80      this.k = k;
81      this.allowedInputVariables = allowedInputVariables.ToArray();
82      double[,] inputMatrix;
83      this.offsets = this.allowedInputVariables
84        .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
85        .Concat(new double[] { 0 }) // no offset for target variable
86        .ToArray();
87      if (weights == null) {
88        // automatic determination of weights (all features should have variance = 1)
89        this.weights = this.allowedInputVariables
90          .Select(name => {
91            var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
92            return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
93          })
94          .Concat(new double[] { 1.0 }) // no scaling for target variable
95          .ToArray();
96      } else {
97        // user specified weights (+ 1 for target)
98        this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
99        if (this.weights.Length - 1 != this.allowedInputVariables.Length)
100          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
101      }
102      inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
103
104      if (inputMatrix.ContainsNanOrInfinity())
105        throw new NotSupportedException(
106          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");
107
108      var nRows = inputMatrix.GetLength(0);
109      var nFeatures = inputMatrix.GetLength(1) - 1;
110
111      if (classValues != null) {
112        this.classValues = (double[])classValues.Clone();
113        int nClasses = classValues.Length;
114        // map original class values to values [0..nClasses-1]
115        var classIndices = new Dictionary<double, double>();
116        for (int i = 0; i < nClasses; i++)
117          classIndices[classValues[i]] = i;
118
119        for (int row = 0; row < nRows; row++) {
120          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
121        }
122      }
123
124      alglib.knnbuildercreate(out var knnbuilder);
125      if (classValues == null) {
126        alglib.knnbuildersetdatasetreg(knnbuilder, inputMatrix, nRows, nFeatures, nout: 1);
127      } else {
128        alglib.knnbuildersetdatasetcls(knnbuilder, inputMatrix, nRows, nFeatures, classValues.Length);
129      }
130      alglib.knnbuilderbuildknnmodel(knnbuilder, k, 0.0, out model, out var report); // eps=0 (exact k-nn search is performed)
131
132    }
133
134    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
135      var transforms =
136        variables.Select(
137          (_, colIdx) =>
138            new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
139      return dataset.ToArray(variables, transforms, rows);
140    }
141
142    public override IDeepCloneable Clone(Cloner cloner) {
143      return new NearestNeighbourModel(this, cloner);
144    }
145
146    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
147      double[,] inputData;
148      inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
149
150      int n = inputData.GetLength(0);
151      int columns = inputData.GetLength(1);
152      double[] x = new double[columns];
153
154      alglib.knncreatebuffer(model, out var buf);
155      var y = new double[1];
156      for (int row = 0; row < n; row++) {
157        for (int column = 0; column < columns; column++) {
158          x[column] = inputData[row, column];
159        }
160        alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
161        yield return y[0];
162      }
163    }
164
165    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
166      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
167      double[,] inputData;
168      inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
169
170      int n = inputData.GetLength(0);
171      int columns = inputData.GetLength(1);
172      double[] x = new double[columns];
173
174      alglib.knncreatebuffer(model, out var buf);
175      var y = new double[classValues.Length];
176      for (int row = 0; row < n; row++) {
177        for (int column = 0; column < columns; column++) {
178          x[column] = inputData[row, column];
179        }
180        alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
181        // find most probably class
182        var maxC = 0;
183        for (int i = 1; i < y.Length; i++)
184          if (maxC < y[i]) maxC = i;
185        yield return classValues[maxC];
186      }
187    }
188
189
190    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
191      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
192    }
193
194    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
195      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");
196
197      var regressionProblemData = problemData as IRegressionProblemData;
198      if (regressionProblemData != null)
199        return IsProblemDataCompatible(regressionProblemData, out errorMessage);
200
201      var classificationProblemData = problemData as IClassificationProblemData;
202      if (classificationProblemData != null)
203        return IsProblemDataCompatible(classificationProblemData, out errorMessage);
204
205      throw new ArgumentException("The problem data is not compatible with this nearest neighbour model. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
206    }
207
208    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
209      return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
210    }
211    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
212      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
213    }
214  }
215}
Note: See TracBrowser for help on using the repository browser.