Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/10/12 14:57:21 (12 years ago)
Author:
abeham
Message:

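#1913: Changed k-NN so that the model representation (the kd-tree) is built and held by the model object itself, instead of being constructed in the algorithm classes and passed into the model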

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourClassification.cs

    r8139 r8465  
    2121
    2222using System;
    23 using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
    2726using HeuristicLab.Data;
    28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2927using HeuristicLab.Optimization;
     28using HeuristicLab.Parameters;
    3029using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3130using HeuristicLab.Problems.DataAnalysis;
    32 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    34 using HeuristicLab.Parameters;
    3531
    3632namespace HeuristicLab.Algorithms.DataAnalysis {
     
    8480
    8581    public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k) {
    86       Dataset dataset = problemData.Dataset;
    87       string targetVariable = problemData.TargetVariable;
    88       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    89       IEnumerable<int> rows = problemData.TrainingIndices;
    90       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    91       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    92         throw new NotSupportedException("Nearest neighbour classification does not support NaN or infinity values in the input dataset.");
     82      var problemDataClone = (IClassificationProblemData)problemData.Clone();
     83      return new NearestNeighbourClassificationSolution(problemDataClone, Train(problemDataClone, k));
     84    }
    9385
    94       alglib.nearestneighbor.kdtree kdtree = new alglib.nearestneighbor.kdtree();
    95 
    96       int nRows = inputMatrix.GetLength(0);
    97       int nFeatures = inputMatrix.GetLength(1) - 1;
    98       double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
    99       int nClasses = classValues.Count();
    100       // map original class values to values [0..nClasses-1]
    101       Dictionary<double, double> classIndices = new Dictionary<double, double>();
    102       for (int i = 0; i < nClasses; i++) {
    103         classIndices[classValues[i]] = i;
    104       }
    105       for (int row = 0; row < nRows; row++) {
    106         inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
    107       }
    108       alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdtree);
    109       var problemDataClone = (IClassificationProblemData) problemData.Clone();
    110       return new NearestNeighbourClassificationSolution(problemDataClone, new NearestNeighbourModel(kdtree, k, targetVariable, allowedInputVariables, problemDataClone.ClassValues.ToArray()));
     86    public static INearestNeighbourModel Train(IClassificationProblemData problemData, int k) {
     87      return new NearestNeighbourModel(problemData.Dataset,
     88        problemData.TrainingIndices,
     89        k,
     90        problemData.TargetVariable,
     91        problemData.AllowedInputVariables,
     92        problemData.ClassValues.ToArray());
    11193    }
    11294    #endregion
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs

    r7294 r8465  
    3333  /// </summary>
    3434  [StorableClass]
    35   [Item("NearestNeighbourModel", "Represents a neural network for regression and classification.")]
     35  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
    3636  public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel {
    3737
     
    5656    [Storable]
    5757    private int k;
     58
    5859    [StorableConstructor]
    5960    private NearestNeighbourModel(bool deserializing)
     
    9596        this.classValues = (double[])original.classValues.Clone();
    9697    }
    97     public NearestNeighbourModel(alglib.nearestneighbor.kdtree kdTree, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    98       : base() {
    99       this.name = ItemName;
    100       this.description = ItemDescription;
    101       this.kdTree = kdTree;
     98    public NearestNeighbourModel(Dataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) {
    10299      this.k = k;
    103100      this.targetVariable = targetVariable;
    104101      this.allowedInputVariables = allowedInputVariables.ToArray();
    105       if (classValues != null)
     102
     103      var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
     104                                   allowedInputVariables.Concat(new string[] { targetVariable }),
     105                                   rows);
     106
     107      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     108        throw new NotSupportedException(
     109          "Nearest neighbour classification does not support NaN or infinity values in the input dataset.");
     110
     111      this.kdTree = new alglib.nearestneighbor.kdtree();
     112
     113      var nRows = inputMatrix.GetLength(0);
     114      var nFeatures = inputMatrix.GetLength(1) - 1;
     115
     116      if (classValues != null) {
    106117        this.classValues = (double[])classValues.Clone();
     118        int nClasses = classValues.Length;
     119        // map original class values to values [0..nClasses-1]
     120        var classIndices = new Dictionary<double, double>();
     121        for (int i = 0; i < nClasses; i++)
     122          classIndices[classValues[i]] = i;
     123
     124        for (int row = 0; row < nRows; row++) {
     125          inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
     126        }
     127      }
     128      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree);
    107129    }
    108130
     
    140162
    141163    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     164      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
    142165      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
    143166
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourRegression.cs

    r8139 r8465  
    2121
    2222using System;
    23 using System.Collections.Generic;
    24 using System.Linq;
    2523using HeuristicLab.Common;
    2624using HeuristicLab.Core;
    2725using HeuristicLab.Data;
    28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2926using HeuristicLab.Optimization;
     27using HeuristicLab.Parameters;
    3028using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3129using HeuristicLab.Problems.DataAnalysis;
    32 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    34 using HeuristicLab.Parameters;
    3530
    3631namespace HeuristicLab.Algorithms.DataAnalysis {
     
    8479
    8580    public static IRegressionSolution CreateNearestNeighbourRegressionSolution(IRegressionProblemData problemData, int k) {
    86       Dataset dataset = problemData.Dataset;
    87       string targetVariable = problemData.TargetVariable;
    88       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    89       IEnumerable<int> rows = problemData.TrainingIndices;
    90       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    91       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    92         throw new NotSupportedException("Nearest neighbour regression does not support NaN or infinity values in the input dataset.");
     81      var clonedProblemData = (IRegressionProblemData)problemData.Clone();
     82      return new NearestNeighbourRegressionSolution(clonedProblemData, Train(problemData, k));
     83    }
    9384
    94       alglib.nearestneighbor.kdtree kdtree = new alglib.nearestneighbor.kdtree();
    95 
    96       int nRows = inputMatrix.GetLength(0);
    97 
    98       alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdtree);
    99 
    100       return new NearestNeighbourRegressionSolution((IRegressionProblemData)problemData.Clone(), new NearestNeighbourModel(kdtree, k, targetVariable, allowedInputVariables));
     85    public static INearestNeighbourModel Train(IRegressionProblemData problemData, int k) {
     86      return new NearestNeighbourModel(problemData.Dataset,
     87        problemData.TrainingIndices,
     88        k,
     89        problemData.TargetVariable,
     90        problemData.AllowedInputVariables);
    10191    }
    10292    #endregion
Note: See TracChangeset for help on using the changeset viewer.