1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System;


23  using System.Collections.Generic;


24  using System.IO;


25  using System.Linq;


26  using System.Text;


27  using HeuristicLab.Common;


28  using HeuristicLab.Core;


29  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


30  using HeuristicLab.Problems.DataAnalysis;


31  using SVM;


32 


33  namespace HeuristicLab.Algorithms.DataAnalysis {


/// <summary>
/// Represents a nearest neighbour model for regression and classification.
/// </summary>
/// <remarks>
/// Predictions are produced by querying an ALGLIB kd-tree for the <c>k</c> nearest stored
/// points. For every neighbour the value in the column directly after the input columns
/// of the kd-tree result matrix is used as the target value.
/// </remarks>
[StorableClass]
[Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel {

  // NOTE(review): this field has no [Storable] attribute and the persistence region at the
  // end of the class is empty, so the tree is presumably not persisted - confirm.
  private alglib.nearestneighbor.kdtree kdTree;

  /// <summary>
  /// Gets or sets the kd-tree that is queried for neighbours.
  /// Assigning a different instance raises <see cref="Changed"/>.
  /// </summary>
  /// <exception cref="ArgumentNullException">Thrown when the assigned value is null.</exception>
  public alglib.nearestneighbor.kdtree KDTree {
    get { return kdTree; }
    set {
      if (value != kdTree) {
        if (value == null) throw new ArgumentNullException("value");
        kdTree = value;
        OnChanged(EventArgs.Empty);
      }
    }
  }

  [Storable]
  private string targetVariable;
  [Storable]
  private string[] allowedInputVariables;
  // Only set for classification models; stays null when no class values were supplied.
  [Storable]
  private double[] classValues;
  // Number of neighbours requested from the kd-tree for every query point.
  [Storable]
  private int k;

  /// <summary>
  /// Constructor used by the persistence framework. When deserializing, an empty kd-tree
  /// instance is created so that the field is never null.
  /// </summary>
  [StorableConstructor]
  private NearestNeighbourModel(bool deserializing)
    : base(deserializing) {
    if (deserializing)
      kdTree = new alglib.nearestneighbor.kdtree();
  }

  /// <summary>
  /// Cloning constructor: copies the kd-tree and all model settings of
  /// <paramref name="original"/> so that the clone shares no arrays with it.
  /// </summary>
  private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
    : base(original, cloner) {
    kdTree = CopyKdTree(original.kdTree);

    k = original.k;
    targetVariable = original.targetVariable;
    allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    if (original.classValues != null)
      this.classValues = (double[])original.classValues.Clone();
  }

  /// <summary>
  /// Creates a new model on top of an already built kd-tree.
  /// </summary>
  /// <param name="kdTree">The kd-tree holding the training points; must not be null.</param>
  /// <param name="k">The number of neighbours to query for each prediction.</param>
  /// <param name="targetVariable">The name of the target variable.</param>
  /// <param name="allowedInputVariables">The input variable names; must not be null.</param>
  /// <param name="classValues">
  /// The class values for classification models (copied), or null for regression models.
  /// </param>
  /// <exception cref="ArgumentNullException">
  /// Thrown when <paramref name="kdTree"/> or <paramref name="allowedInputVariables"/> is null.
  /// </exception>
  public NearestNeighbourModel(alglib.nearestneighbor.kdtree kdTree, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    : base() {
    // Same contract as the KDTree setter: a model without a tree cannot predict or be cloned.
    if (kdTree == null) throw new ArgumentNullException("kdTree");
    if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");

    this.name = ItemName;
    this.description = ItemDescription;
    this.kdTree = kdTree;
    this.k = k;
    this.targetVariable = targetVariable;
    this.allowedInputVariables = allowedInputVariables.ToArray();
    if (classValues != null)
      this.classValues = (double[])classValues.Clone();
  }

  /// <summary>
  /// Creates a deep clone of this model.
  /// </summary>
  public override IDeepCloneable Clone(Cloner cloner) {
    return new NearestNeighbourModel(this, cloner);
  }

  /// <summary>
  /// Copies every field of <paramref name="source"/> into a new kd-tree instance;
  /// array fields are cloned so that the copy does not alias the source.
  /// </summary>
  private static alglib.nearestneighbor.kdtree CopyKdTree(alglib.nearestneighbor.kdtree source) {
    var copy = new alglib.nearestneighbor.kdtree();
    copy.approxf = source.approxf;
    copy.boxmax = (double[])source.boxmax.Clone();
    copy.boxmin = (double[])source.boxmin.Clone();
    copy.buf = (double[])source.buf.Clone();
    copy.curboxmax = (double[])source.curboxmax.Clone();
    copy.curboxmin = (double[])source.curboxmin.Clone();
    copy.curdist = source.curdist;
    copy.debugcounter = source.debugcounter;
    copy.distmatrixtype = source.distmatrixtype;
    copy.idx = (int[])source.idx.Clone();
    copy.kcur = source.kcur;
    copy.kneeded = source.kneeded;
    copy.n = source.n;
    copy.nodes = (int[])source.nodes.Clone();
    copy.normtype = source.normtype;
    copy.nx = source.nx;
    copy.ny = source.ny;
    copy.r = (double[])source.r.Clone();
    copy.rneeded = source.rneeded;
    copy.selfmatch = source.selfmatch;
    copy.splits = (double[])source.splits.Clone();
    copy.tags = (int[])source.tags.Clone();
    copy.x = (double[])source.x.Clone();
    copy.xy = (double[,])source.xy.Clone();
    return copy;
  }

  /// <summary>
  /// Lazily estimates the target value for each of the given rows as the
  /// inverse-distance weighted mean of the target values of the nearest neighbours.
  /// </summary>
  /// <param name="dataset">The dataset to read the input variables from.</param>
  /// <param name="rows">The rows for which estimates are produced.</param>
  /// <returns>
  /// One estimate per row of the prepared input matrix. The estimate is NaN when the
  /// kd-tree query returns no neighbours (0.0 / 0.0).
  /// </returns>
  public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

    int n = inputData.GetLength(0);
    int columns = inputData.GetLength(1);
    double[] x = new double[columns];
    double[] dists = new double[k];
    double[,] neighbours = new double[k, columns + 1];

    for (int row = 0; row < n; row++) {
      for (int column = 0; column < columns; column++) {
        x[column] = inputData[row, column];
      }
      // selfMatch is false; per the ALGLIB documentation this excludes points at distance
      // zero, so the 1/distance weights below are presumably finite - confirm for the
      // bundled ALGLIB version.
      int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
      alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
      alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);

      // Fewer than k neighbours may be returned, so only the first actNeighbours entries
      // of the result buffers are used.
      double weightedTargetSum = 0.0;
      double weightSum = 0.0;
      for (int i = 0; i < actNeighbours; i++) {
        weightedTargetSum += neighbours[i, columns] / dists[i];
        weightSum += 1.0 / dists[i];
      }
      yield return weightedTargetSum / weightSum;
    }
  }

  /// <summary>
  /// Lazily estimates the class value for each of the given rows by a majority vote
  /// among the nearest neighbours.
  /// </summary>
  /// <param name="dataset">The dataset to read the input variables from.</param>
  /// <param name="rows">The rows for which class estimates are produced.</param>
  /// <returns>
  /// One class value per row of the prepared input matrix. Ties are resolved in favour
  /// of the class with the lowest index in the class values array.
  /// </returns>
  /// <exception cref="InvalidOperationException">
  /// Thrown on enumeration when the model was created without class values.
  /// </exception>
  public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    if (classValues == null)
      throw new InvalidOperationException("The model was created without class values and cannot estimate class values.");

    double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);

    int n = inputData.GetLength(0);
    int columns = inputData.GetLength(1);
    double[] x = new double[columns];
    int[] votes = new int[classValues.Length];
    double[] dists = new double[k];
    double[,] neighbours = new double[k, columns + 1];

    for (int row = 0; row < n; row++) {
      for (int column = 0; column < columns; column++) {
        x[column] = inputData[row, column];
      }
      int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
      alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
      alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);

      Array.Clear(votes, 0, votes.Length);
      for (int i = 0; i < actNeighbours; i++) {
        // NOTE(review): the stored target is used directly as an index into the vote
        // array, so the kd-tree is assumed to hold class indices in
        // [0, classValues.Length) rather than the raw class values - confirm against
        // the code that builds the tree.
        int classIndex = (int)Math.Round(neighbours[i, columns]);
        votes[classIndex]++;
      }

      // find the class with the largest number of votes (first one wins on ties)
      int bestClassIndex = 0;
      int maxVotes = votes[0];
      for (int i = 1; i < votes.Length; i++) {
        if (maxVotes < votes[i]) {
          maxVotes = votes[i];
          bestClassIndex = i;
        }
      }
      yield return classValues[bestClassIndex];
    }
  }

  #region events
  /// <summary>
  /// Raised after a different kd-tree instance has been assigned via <see cref="KDTree"/>.
  /// </summary>
  public event EventHandler Changed;
  private void OnChanged(EventArgs e) {
    // Copy to a local first so the null check and the invocation see the same delegate.
    var handlers = Changed;
    if (handlers != null)
      handlers(this, e);
  }
  #endregion

  #region persistence
  // not implemented yet
  #endregion
}


195  }

