Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/10/12 14:57:21 (12 years ago)
Author:
abeham
Message:

#1913: Changed k-NN so that the model representation (the kd-tree) is built and stored inside the model object instead of being passed in by the caller

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs

    r7294 r8465  
    3333  /// </summary>
    3434  [StorableClass]
    35   [Item("NearestNeighbourModel", "Represents a neural network for regression and classification.")]
     35  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
    3636  public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel {
    3737
     
    5656    [Storable]
    5757    private int k;
     58
    5859    [StorableConstructor]
    5960    private NearestNeighbourModel(bool deserializing)
     
    9596        this.classValues = (double[])original.classValues.Clone();
    9697    }
    97     public NearestNeighbourModel(alglib.nearestneighbor.kdtree kdTree, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    98       : base() {
    99       this.name = ItemName;
    100       this.description = ItemDescription;
    101       this.kdTree = kdTree;
     98    public NearestNeighbourModel(Dataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) {
    10299      this.k = k;
    103100      this.targetVariable = targetVariable;
    104101      this.allowedInputVariables = allowedInputVariables.ToArray();
    105       if (classValues != null)
     102
     103      var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
     104                                   allowedInputVariables.Concat(new string[] { targetVariable }),
     105                                   rows);
     106
     107      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     108        throw new NotSupportedException(
     109          "Nearest neighbour classification does not support NaN or infinity values in the input dataset.");
     110
     111      this.kdTree = new alglib.nearestneighbor.kdtree();
     112
     113      var nRows = inputMatrix.GetLength(0);
     114      var nFeatures = inputMatrix.GetLength(1) - 1;
     115
     116      if (classValues != null) {
    106117        this.classValues = (double[])classValues.Clone();
     118        int nClasses = classValues.Length;
     119        // map original class values to values [0..nClasses-1]
     120        var classIndices = new Dictionary<double, double>();
     121        for (int i = 0; i < nClasses; i++)
     122          classIndices[classValues[i]] = i;
     123
     124        for (int row = 0; row < nRows; row++) {
     125          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
     126        }
     127      }
     128      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree);
    107129    }
    108130
     
    140162
    141163    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     164      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
    142165      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
    143166
Note: See TracChangeset for help on using the changeset viewer.