Free cookie consent management tool by TermsFeed Policy Generator

Changeset 6583 for trunk/sources


Ignore:
Timestamp:
07/21/11 16:28:00 (13 years ago)
Author:
gkronber
Message:

#763 added first implementation of classification and regression based on k nearest neighbour.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
4 added
1 edited
5 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r6580 r6583  
    112112    <Compile Include="HeuristicLabAlgorithmsDataAnalysisPlugin.cs" />
    113113    <Compile Include="FixedDataAnalysisAlgorithm.cs" />
     114    <Compile Include="Interfaces\INearestNeighbourClassificationSolution.cs" />
     115    <Compile Include="Interfaces\INearestNeighbourRegressionSolution.cs" />
     116    <Compile Include="Interfaces\INearestNeighbourModel.cs" />
    114117    <Compile Include="Interfaces\INeuralNetworkEnsembleClassificationSolution.cs" />
    115118    <Compile Include="Interfaces\INeuralNetworkEnsembleRegressionSolution.cs" />
     
    138141    <Compile Include="Linear\MultinomialLogitClassificationSolution.cs" />
    139142    <Compile Include="Linear\MultinomialLogitModel.cs" />
     143    <Compile Include="NearestNeighbour\NearestNeighbourClassification.cs" />
     144    <Compile Include="NearestNeighbour\NearestNeighbourClassificationSolution.cs" />
     145    <Compile Include="NearestNeighbour\NearestNeighbourModel.cs" />
     146    <Compile Include="NearestNeighbour\NearestNeighbourRegression.cs" />
     147    <Compile Include="NearestNeighbour\NearestNeighbourRegressionSolution.cs" />
    140148    <Compile Include="NeuralNetwork\NeuralNetworkEnsembleClassification.cs" />
    141149    <Compile Include="NeuralNetwork\NeuralNetworkEnsembleClassificationSolution.cs" />
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/INearestNeighbourClassificationSolution.cs

    r6577 r6583  
    2626namespace HeuristicLab.Algorithms.DataAnalysis {
    2727  /// <summary>
    28   /// Interface to represent a neural network regression solution
     28  /// Interface to represent a nearest neighbour classification solution
    2929  /// </summary>
    30   public interface INeuralNetworkRegressionSolution : IRegressionSolution {
    31     new INeuralNetworkModel Model { get; }
     30  public interface INearestNeighbourClassificationSolution : IClassificationSolution {
     31    new INearestNeighbourModel Model { get; }
    3232  }
    3333}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/INearestNeighbourModel.cs

    r6577 r6583  
    2727namespace HeuristicLab.Algorithms.DataAnalysis {
    2828  /// <summary>
    29   /// Interface to represent a random forest model for either regression or classification
     29  /// Interface to represent a nearest neighbour model for either regression or classification
    3030  /// </summary>
    31   public interface IRandomForestModel : IRegressionModel, IClassificationModel {
     31  public interface INearestNeighbourModel : IRegressionModel, IClassificationModel {
    3232  }
    3333}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/Interfaces/INearestNeighbourRegressionSolution.cs

    r6577 r6583  
    2626namespace HeuristicLab.Algorithms.DataAnalysis {
    2727  /// <summary>
    28   /// Interface to represent a neural network regression solution
     28  /// Interface to represent a nearest neighbour regression solution
    2929  /// </summary>
    30   public interface INeuralNetworkRegressionSolution : IRegressionSolution {
    31     new INeuralNetworkModel Model { get; }
     30  public interface INearestNeighbourRegressionSolution : IRegressionSolution {
     31    new INearestNeighbourModel Model { get; }
    3232  }
    3333}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourClassification.cs

    r6580 r6583  
    3636namespace HeuristicLab.Algorithms.DataAnalysis {
    3737  /// <summary>
    38   /// Neural network regression data analysis algorithm.
     38  /// Nearest neighbour classification data analysis algorithm.
    3939  /// </summary>
    40   [Item("Neural Network Regression", "Neural network regression data analysis algorithm (wrapper for ALGLIB). Further documentation: http://www.alglib.net/dataanalysis/neuralnetworks.php")]
     40  [Item("Nearest Neighbour Classification", "Nearest neighbour classification data analysis algorithm (wrapper for ALGLIB).")]
    4141  [Creatable("Data Analysis")]
    4242  [StorableClass]
    43   public sealed class NeuralNetworkRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    44     private const string DecayParameterName = "Decay";
    45     private const string HiddenLayersParameterName = "HiddenLayers";
    46     private const string NodesInFirstHiddenLayerParameterName = "NodesInFirstHiddenLayer";
    47     private const string NodesInSecondHiddenLayerParameterName = "NodesInSecondHiddenLayer";
    48     private const string RestartsParameterName = "Restarts";
    49     private const string NeuralNetworkRegressionModelResultName = "Neural network regression solution";
     43  public sealed class NearestNeighbourClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
     44    private const string KParameterName = "K";
     45    private const string NearestNeighbourClassificationModelResultName = "Nearest neighbour classification solution";
    5046
    5147    #region parameter properties
    52     public IFixedValueParameter<DoubleValue> DecayParameter {
    53       get { return (IFixedValueParameter<DoubleValue>)Parameters[DecayParameterName]; }
    54     }
    55     public ConstrainedValueParameter<IntValue> HiddenLayersParameter {
    56       get { return (ConstrainedValueParameter<IntValue>)Parameters[HiddenLayersParameterName]; }
    57     }
    58     public IFixedValueParameter<IntValue> NodesInFirstHiddenLayerParameter {
    59       get { return (IFixedValueParameter<IntValue>)Parameters[NodesInFirstHiddenLayerParameterName]; }
    60     }
    61     public IFixedValueParameter<IntValue> NodesInSecondHiddenLayerParameter {
    62       get { return (IFixedValueParameter<IntValue>)Parameters[NodesInSecondHiddenLayerParameterName]; }
    63     }
    64     public IFixedValueParameter<IntValue> RestartsParameter {
    65       get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
     48    public IFixedValueParameter<IntValue> KParameter {
     49      get { return (IFixedValueParameter<IntValue>)Parameters[KParameterName]; }
    6650    }
    6751    #endregion
    68 
    6952    #region properties
    70     public double Decay {
    71       get { return DecayParameter.Value.Value; }
     53    public int K {
     54      get { return KParameter.Value.Value; }
    7255      set {
    73         if (value < 0.001 || value > 100) throw new ArgumentException("The decay parameter should be set to a value between 0.001 and 100.", "Decay");
    74         DecayParameter.Value.Value = value;
    75       }
    76     }
    77     public int HiddenLayers {
    78       get { return HiddenLayersParameter.Value.Value; }
    79       set {
    80         if (value < 0 || value > 2) throw new ArgumentException("The number of hidden layers should be set to 0, 1, or 2.", "HiddenLayers");
    81         HiddenLayersParameter.Value = (from v in HiddenLayersParameter.ValidValues
    82                                        where v.Value == value
    83                                        select v)
    84                                       .Single();
    85       }
    86     }
    87     public int NodesInFirstHiddenLayer {
    88       get { return NodesInFirstHiddenLayerParameter.Value.Value; }
    89       set {
    90         if (value < 1) throw new ArgumentException("The number of nodes in the first hidden layer must be at least one.", "NodesInFirstHiddenLayer");
    91         NodesInFirstHiddenLayerParameter.Value.Value = value;
    92       }
    93     }
    94     public int NodesInSecondHiddenLayer {
    95       get { return NodesInSecondHiddenLayerParameter.Value.Value; }
    96       set {
    97         if (value < 1) throw new ArgumentException("The number of nodes in the first second layer must be at least one.", "NodesInSecondHiddenLayer");
    98         NodesInSecondHiddenLayerParameter.Value.Value = value;
    99       }
    100     }
    101     public int Restarts {
    102       get { return RestartsParameter.Value.Value; }
    103       set {
    104         if (value < 0) throw new ArgumentException("The number of restarts must be positive.", "Restarts");
    105         RestartsParameter.Value.Value = value;
     56        if (value <= 0) throw new ArgumentException("K must be larger than zero.", "K");
     57        else KParameter.Value.Value = value;
    10658      }
    10759    }
    10860    #endregion
    10961
    110 
    11162    [StorableConstructor]
    112     private NeuralNetworkRegression(bool deserializing) : base(deserializing) { }
    113     private NeuralNetworkRegression(NeuralNetworkRegression original, Cloner cloner)
     63    private NearestNeighbourClassification(bool deserializing) : base(deserializing) { }
     64    private NearestNeighbourClassification(NearestNeighbourClassification original, Cloner cloner)
    11465      : base(original, cloner) {
    11566    }
    116     public NeuralNetworkRegression()
     67    public NearestNeighbourClassification()
    11768      : base() {
    118       var validHiddenLayerValues = new ItemSet<IntValue>(new IntValue[] { new IntValue(0), new IntValue(1), new IntValue(2) });
    119       var selectedHiddenLayerValue = (from v in validHiddenLayerValues
    120                                       where v.Value == 1
    121                                       select v)
    122                                      .Single();
    123       Parameters.Add(new FixedValueParameter<DoubleValue>(DecayParameterName, "The decay parameter for the training phase of the neural network. This parameter determines the strengh of regularization and should be set to a value between 0.001 (weak regularization) to 100 (very strong regularization). The correct value should be determined via cross-validation.", new DoubleValue(1)));
    124       Parameters.Add(new ConstrainedValueParameter<IntValue>(HiddenLayersParameterName, "The number of hidden layers for the neural network (0, 1, or 2)", validHiddenLayerValues, selectedHiddenLayerValue));
    125       Parameters.Add(new FixedValueParameter<IntValue>(NodesInFirstHiddenLayerParameterName, "The number of nodes in the first hidden layer. This value is not used if the number of hidden layers is zero.", new IntValue(10)));
    126       Parameters.Add(new FixedValueParameter<IntValue>(NodesInSecondHiddenLayerParameterName, "The number of nodes in the second hidden layer. This value is not used if the number of hidden layers is zero or one.", new IntValue(10)));
    127       Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of restarts for learning.", new IntValue(2)));
    128 
    129       Problem = new RegressionProblem();
     69      Parameters.Add(new FixedValueParameter<IntValue>(KParameterName, "The number of nearest neighbours to consider for regression.", new IntValue(3)));
     70      Problem = new ClassificationProblem();
    13071    }
    13172    [StorableHook(HookType.AfterDeserialization)]
     
    13374
    13475    public override IDeepCloneable Clone(Cloner cloner) {
    135       return new NeuralNetworkRegression(this, cloner);
     76      return new NearestNeighbourClassification(this, cloner);
    13677    }
    13778
    138     #region neural network
     79    #region nearest neighbour
    13980    protected override void Run() {
    140       double rmsError, avgRelError;
    141       var solution = CreateNeuralNetworkRegressionSolution(Problem.ProblemData, HiddenLayers, NodesInFirstHiddenLayer, NodesInSecondHiddenLayer, Decay, Restarts, out rmsError, out avgRelError);
    142       Results.Add(new Result(NeuralNetworkRegressionModelResultName, "The neural network regression solution.", solution));
    143       Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the neural network regression solution on the training set.", new DoubleValue(rmsError)));
    144       Results.Add(new Result("Average relative error", "The average of relative errors of the neural network regression solution on the training set.", new PercentValue(avgRelError)));
     81      var solution = CreateNearestNeighbourClassificationSolution(Problem.ProblemData, K);
     82      Results.Add(new Result(NearestNeighbourClassificationModelResultName, "The nearest neighbour classification solution.", solution));
    14583    }
    14684
    147     public static IRegressionSolution CreateNeuralNetworkRegressionSolution(IRegressionProblemData problemData, int nLayers, int nHiddenNodes1, int nHiddenNodes2, double decay, int restarts,
    148       out double rmsError, out double avgRelError) {
     85    public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k) {
    14986      Dataset dataset = problemData.Dataset;
    15087      string targetVariable = problemData.TargetVariable;
     
    15390      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    15491      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    155         throw new NotSupportedException("Neural network regression does not support NaN or infinity values in the input dataset.");
     92        throw new NotSupportedException("Nearest neighbour classification does not support NaN or infinity values in the input dataset.");
    15693
    157       double targetMin = problemData.Dataset.GetEnumeratedVariableValues(targetVariable).Min();
    158       targetMin = targetMin - targetMin * 0.1; // -10%
    159       double targetMax = problemData.Dataset.GetEnumeratedVariableValues(targetVariable).Max();
    160       targetMax = targetMax + targetMax * 0.1; // + 10%
     94      alglib.nearestneighbor.kdtree kdtree = new alglib.nearestneighbor.kdtree();
    16195
    162       alglib.multilayerperceptron multiLayerPerceptron = null;
    163       if (nLayers == 0) {
    164         alglib.mlpcreater0(allowedInputVariables.Count(), 1, targetMin, targetMax, out multiLayerPerceptron);
    165       } else if (nLayers == 1) {
    166         alglib.mlpcreater1(allowedInputVariables.Count(), nHiddenNodes1, 1, targetMin, targetMax, out multiLayerPerceptron);
    167       } else if (nLayers == 2) {
    168         alglib.mlpcreater2(allowedInputVariables.Count(), nHiddenNodes1, nHiddenNodes2, 1, targetMin, targetMax, out multiLayerPerceptron);
    169       } else throw new ArgumentException("Number of layers must be zero, one, or two.", "nLayers");
    170       alglib.mlpreport rep;
    17196      int nRows = inputMatrix.GetLength(0);
     97      int nFeatures = inputMatrix.GetLength(1) - 1;
     98      double[] classValues = dataset.GetVariableValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
     99      int nClasses = classValues.Count();
     100      // map original class values to values [0..nClasses-1]
     101      Dictionary<double, double> classIndizes = new Dictionary<double, double>();
     102      for (int i = 0; i < nClasses; i++) {
     103        classIndizes[classValues[i]] = i;
     104      }
     105      for (int row = 0; row < nRows; row++) {
     106        inputMatrix[row, nFeatures] = classIndizes[inputMatrix[row, nFeatures]];
     107      }
     108      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdtree);
    172109
    173       int info;
    174       // using mlptrainlm instead of mlptraines or mlptrainbfgs because only one parameter is necessary
    175       alglib.mlptrainlm(multiLayerPerceptron, inputMatrix, nRows, decay, restarts, out info, out rep);
    176       if (info != 2) throw new ArgumentException("Error in calculation of neural network regression solution");
    177 
    178       rmsError = alglib.mlprmserror(multiLayerPerceptron, inputMatrix, nRows);
    179       avgRelError = alglib.mlpavgrelerror(multiLayerPerceptron, inputMatrix, nRows);     
    180 
    181       return new NeuralNetworkRegressionSolution(problemData, new NeuralNetworkModel(multiLayerPerceptron, targetVariable, allowedInputVariables));
     110      return new NearestNeighbourClassificationSolution(problemData, new NearestNeighbourModel(kdtree, k, targetVariable, allowedInputVariables, problemData.ClassValues.ToArray()));
    182111    }
    183112    #endregion
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourRegression.cs

    r6580 r6583  
    3636namespace HeuristicLab.Algorithms.DataAnalysis {
    3737  /// <summary>
    38   /// Neural network regression data analysis algorithm.
     38  /// Nearest neighbour regression data analysis algorithm.
    3939  /// </summary>
    40   [Item("Neural Network Regression", "Neural network regression data analysis algorithm (wrapper for ALGLIB). Further documentation: http://www.alglib.net/dataanalysis/neuralnetworks.php")]
     40  [Item("Nearest Neighbour Regression", "Nearest neighbour regression data analysis algorithm (wrapper for ALGLIB).")]
    4141  [Creatable("Data Analysis")]
    4242  [StorableClass]
    43   public sealed class NeuralNetworkRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    44     private const string DecayParameterName = "Decay";
    45     private const string HiddenLayersParameterName = "HiddenLayers";
    46     private const string NodesInFirstHiddenLayerParameterName = "NodesInFirstHiddenLayer";
    47     private const string NodesInSecondHiddenLayerParameterName = "NodesInSecondHiddenLayer";
    48     private const string RestartsParameterName = "Restarts";
    49     private const string NeuralNetworkRegressionModelResultName = "Neural network regression solution";
     43  public sealed class NearestNeighbourRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     44    private const string KParameterName = "K";
     45    private const string NearestNeighbourRegressionModelResultName = "Nearest neighbour regression solution";
    5046
    5147    #region parameter properties
    52     public IFixedValueParameter<DoubleValue> DecayParameter {
    53       get { return (IFixedValueParameter<DoubleValue>)Parameters[DecayParameterName]; }
    54     }
    55     public ConstrainedValueParameter<IntValue> HiddenLayersParameter {
    56       get { return (ConstrainedValueParameter<IntValue>)Parameters[HiddenLayersParameterName]; }
    57     }
    58     public IFixedValueParameter<IntValue> NodesInFirstHiddenLayerParameter {
    59       get { return (IFixedValueParameter<IntValue>)Parameters[NodesInFirstHiddenLayerParameterName]; }
    60     }
    61     public IFixedValueParameter<IntValue> NodesInSecondHiddenLayerParameter {
    62       get { return (IFixedValueParameter<IntValue>)Parameters[NodesInSecondHiddenLayerParameterName]; }
    63     }
    64     public IFixedValueParameter<IntValue> RestartsParameter {
    65       get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
     48    public IFixedValueParameter<IntValue> KParameter {
     49      get { return (IFixedValueParameter<IntValue>)Parameters[KParameterName]; }
    6650    }
    6751    #endregion
    68 
    6952    #region properties
    70     public double Decay {
    71       get { return DecayParameter.Value.Value; }
     53    public int K {
     54      get { return KParameter.Value.Value; }
    7255      set {
    73         if (value < 0.001 || value > 100) throw new ArgumentException("The decay parameter should be set to a value between 0.001 and 100.", "Decay");
    74         DecayParameter.Value.Value = value;
    75       }
    76     }
    77     public int HiddenLayers {
    78       get { return HiddenLayersParameter.Value.Value; }
    79       set {
    80         if (value < 0 || value > 2) throw new ArgumentException("The number of hidden layers should be set to 0, 1, or 2.", "HiddenLayers");
    81         HiddenLayersParameter.Value = (from v in HiddenLayersParameter.ValidValues
    82                                        where v.Value == value
    83                                        select v)
    84                                       .Single();
    85       }
    86     }
    87     public int NodesInFirstHiddenLayer {
    88       get { return NodesInFirstHiddenLayerParameter.Value.Value; }
    89       set {
    90         if (value < 1) throw new ArgumentException("The number of nodes in the first hidden layer must be at least one.", "NodesInFirstHiddenLayer");
    91         NodesInFirstHiddenLayerParameter.Value.Value = value;
    92       }
    93     }
    94     public int NodesInSecondHiddenLayer {
    95       get { return NodesInSecondHiddenLayerParameter.Value.Value; }
    96       set {
    97         if (value < 1) throw new ArgumentException("The number of nodes in the first second layer must be at least one.", "NodesInSecondHiddenLayer");
    98         NodesInSecondHiddenLayerParameter.Value.Value = value;
    99       }
    100     }
    101     public int Restarts {
    102       get { return RestartsParameter.Value.Value; }
    103       set {
    104         if (value < 0) throw new ArgumentException("The number of restarts must be positive.", "Restarts");
    105         RestartsParameter.Value.Value = value;
     56        if (value <= 0) throw new ArgumentException("K must be larger than zero.", "K");
     57        else KParameter.Value.Value = value;
    10658      }
    10759    }
    10860    #endregion
    10961
    110 
    11162    [StorableConstructor]
    112     private NeuralNetworkRegression(bool deserializing) : base(deserializing) { }
    113     private NeuralNetworkRegression(NeuralNetworkRegression original, Cloner cloner)
     63    private NearestNeighbourRegression(bool deserializing) : base(deserializing) { }
     64    private NearestNeighbourRegression(NearestNeighbourRegression original, Cloner cloner)
    11465      : base(original, cloner) {
    11566    }
    116     public NeuralNetworkRegression()
     67    public NearestNeighbourRegression()
    11768      : base() {
    118       var validHiddenLayerValues = new ItemSet<IntValue>(new IntValue[] { new IntValue(0), new IntValue(1), new IntValue(2) });
    119       var selectedHiddenLayerValue = (from v in validHiddenLayerValues
    120                                       where v.Value == 1
    121                                       select v)
    122                                      .Single();
    123       Parameters.Add(new FixedValueParameter<DoubleValue>(DecayParameterName, "The decay parameter for the training phase of the neural network. This parameter determines the strengh of regularization and should be set to a value between 0.001 (weak regularization) to 100 (very strong regularization). The correct value should be determined via cross-validation.", new DoubleValue(1)));
    124       Parameters.Add(new ConstrainedValueParameter<IntValue>(HiddenLayersParameterName, "The number of hidden layers for the neural network (0, 1, or 2)", validHiddenLayerValues, selectedHiddenLayerValue));
    125       Parameters.Add(new FixedValueParameter<IntValue>(NodesInFirstHiddenLayerParameterName, "The number of nodes in the first hidden layer. This value is not used if the number of hidden layers is zero.", new IntValue(10)));
    126       Parameters.Add(new FixedValueParameter<IntValue>(NodesInSecondHiddenLayerParameterName, "The number of nodes in the second hidden layer. This value is not used if the number of hidden layers is zero or one.", new IntValue(10)));
    127       Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of restarts for learning.", new IntValue(2)));
    128 
     69      Parameters.Add(new FixedValueParameter<IntValue>(KParameterName, "The number of nearest neighbours to consider for regression.", new IntValue(3)));
    12970      Problem = new RegressionProblem();
    13071    }
     
    13374
    13475    public override IDeepCloneable Clone(Cloner cloner) {
    135       return new NeuralNetworkRegression(this, cloner);
     76      return new NearestNeighbourRegression(this, cloner);
    13677    }
    13778
    138     #region neural network
     79    #region nearest neighbour
    13980    protected override void Run() {
    140       double rmsError, avgRelError;
    141       var solution = CreateNeuralNetworkRegressionSolution(Problem.ProblemData, HiddenLayers, NodesInFirstHiddenLayer, NodesInSecondHiddenLayer, Decay, Restarts, out rmsError, out avgRelError);
    142       Results.Add(new Result(NeuralNetworkRegressionModelResultName, "The neural network regression solution.", solution));
    143       Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the neural network regression solution on the training set.", new DoubleValue(rmsError)));
    144       Results.Add(new Result("Average relative error", "The average of relative errors of the neural network regression solution on the training set.", new PercentValue(avgRelError)));
     81      var solution = CreateNearestNeighbourRegressionSolution(Problem.ProblemData, K);
     82      Results.Add(new Result(NearestNeighbourRegressionModelResultName, "The nearest neighbour regression solution.", solution));
    14583    }
    14684
    147     public static IRegressionSolution CreateNeuralNetworkRegressionSolution(IRegressionProblemData problemData, int nLayers, int nHiddenNodes1, int nHiddenNodes2, double decay, int restarts,
    148       out double rmsError, out double avgRelError) {
     85    public static IRegressionSolution CreateNearestNeighbourRegressionSolution(IRegressionProblemData problemData, int k) {
    14986      Dataset dataset = problemData.Dataset;
    15087      string targetVariable = problemData.TargetVariable;
     
    15390      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    15491      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    155         throw new NotSupportedException("Neural network regression does not support NaN or infinity values in the input dataset.");
     92        throw new NotSupportedException("Nearest neighbour regression does not support NaN or infinity values in the input dataset.");
    15693
    157       double targetMin = problemData.Dataset.GetEnumeratedVariableValues(targetVariable).Min();
    158       targetMin = targetMin - targetMin * 0.1; // -10%
    159       double targetMax = problemData.Dataset.GetEnumeratedVariableValues(targetVariable).Max();
    160       targetMax = targetMax + targetMax * 0.1; // + 10%
     94      alglib.nearestneighbor.kdtree kdtree = new alglib.nearestneighbor.kdtree();
    16195
    162       alglib.multilayerperceptron multiLayerPerceptron = null;
    163       if (nLayers == 0) {
    164         alglib.mlpcreater0(allowedInputVariables.Count(), 1, targetMin, targetMax, out multiLayerPerceptron);
    165       } else if (nLayers == 1) {
    166         alglib.mlpcreater1(allowedInputVariables.Count(), nHiddenNodes1, 1, targetMin, targetMax, out multiLayerPerceptron);
    167       } else if (nLayers == 2) {
    168         alglib.mlpcreater2(allowedInputVariables.Count(), nHiddenNodes1, nHiddenNodes2, 1, targetMin, targetMax, out multiLayerPerceptron);
    169       } else throw new ArgumentException("Number of layers must be zero, one, or two.", "nLayers");
    170       alglib.mlpreport rep;
    17196      int nRows = inputMatrix.GetLength(0);
    17297
    173       int info;
    174       // using mlptrainlm instead of mlptraines or mlptrainbfgs because only one parameter is necessary
    175       alglib.mlptrainlm(multiLayerPerceptron, inputMatrix, nRows, decay, restarts, out info, out rep);
    176       if (info != 2) throw new ArgumentException("Error in calculation of neural network regression solution");
     98      alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdtree);
    17799
    178       rmsError = alglib.mlprmserror(multiLayerPerceptron, inputMatrix, nRows);
    179       avgRelError = alglib.mlpavgrelerror(multiLayerPerceptron, inputMatrix, nRows);     
    180 
    181       return new NeuralNetworkRegressionSolution(problemData, new NeuralNetworkModel(multiLayerPerceptron, targetVariable, allowedInputVariables));
     100      return new NearestNeighbourRegressionSolution(problemData, new NearestNeighbourModel(kdtree, k, targetVariable, allowedInputVariables));
    182101    }
    183102    #endregion
Note: See TracChangeset for help on using the changeset viewer.