Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NeuralNetwork/NeuralNetworkEnsembleClassification.cs @ 6580

Last change on this file since 6580 was 6580, checked in by gkronber, 13 years ago

#1474: added implementations for regression and classification with neural network ensembles (wrappers for alglib).

File size: 11.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
34using HeuristicLab.Parameters;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Neural network ensemble classification data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Wraps the ALGLIB multilayer-perceptron ensemble routines: an ensemble is created with
  /// <c>mlpecreatec0/1/2</c> (depending on the number of hidden layers) and trained with
  /// <c>mlpetraines</c> on the training partition of the problem data.
  /// </remarks>
  [Item("Neural Network Ensemble Classification", "Neural network ensemble classification data analysis algorithm (wrapper for ALGLIB). Further documentation: http://www.alglib.net/dataanalysis/mlpensembles.php")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class NeuralNetworkEnsembleClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string EnsembleSizeParameterName = "EnsembleSize";
    private const string DecayParameterName = "Decay";
    private const string HiddenLayersParameterName = "HiddenLayers";
    private const string NodesInFirstHiddenLayerParameterName = "NodesInFirstHiddenLayer";
    private const string NodesInSecondHiddenLayerParameterName = "NodesInSecondHiddenLayer";
    private const string RestartsParameterName = "Restarts";
    private const string NeuralNetworkEnsembleClassificationModelResultName = "Neural network ensemble classification solution";

    #region parameter properties
    /// <summary>Parameter holding the number of networks in the ensemble.</summary>
    public IFixedValueParameter<IntValue> EnsembleSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EnsembleSizeParameterName]; }
    }
    /// <summary>Parameter holding the weight-decay (regularization) coefficient.</summary>
    public IFixedValueParameter<DoubleValue> DecayParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[DecayParameterName]; }
    }
    /// <summary>Parameter holding the number of hidden layers; constrained to the values 0, 1 and 2.</summary>
    public ConstrainedValueParameter<IntValue> HiddenLayersParameter {
      get { return (ConstrainedValueParameter<IntValue>)Parameters[HiddenLayersParameterName]; }
    }
    /// <summary>Parameter holding the number of nodes in the first hidden layer.</summary>
    public IFixedValueParameter<IntValue> NodesInFirstHiddenLayerParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NodesInFirstHiddenLayerParameterName]; }
    }
    /// <summary>Parameter holding the number of nodes in the second hidden layer.</summary>
    public IFixedValueParameter<IntValue> NodesInSecondHiddenLayerParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NodesInSecondHiddenLayerParameterName]; }
    }
    /// <summary>Parameter holding the number of training restarts.</summary>
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Gets or sets the number of neural networks in the ensemble.
    /// </summary>
    /// <exception cref="ArgumentException">The value is smaller than one.</exception>
    public int EnsembleSize {
      get { return EnsembleSizeParameter.Value.Value; }
      set {
        if (value < 1) throw new ArgumentException("The number of models in the ensemble must be positive and at least one.", "EnsembleSize");
        EnsembleSizeParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Gets or sets the weight-decay coefficient used during training.
    /// </summary>
    /// <exception cref="ArgumentException">The value lies outside [0.001, 100].</exception>
    public double Decay {
      get { return DecayParameter.Value.Value; }
      set {
        if (value < 0.001 || value > 100) throw new ArgumentException("The decay parameter should be set to a value between 0.001 and 100.", "Decay");
        DecayParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Gets or sets the number of hidden layers (0, 1, or 2).
    /// </summary>
    /// <remarks>
    /// The setter selects the matching instance from the parameter's valid values instead of
    /// creating a new <see cref="IntValue"/>, because a constrained parameter only accepts its own values.
    /// </remarks>
    /// <exception cref="ArgumentException">The value is not 0, 1, or 2.</exception>
    public int HiddenLayers {
      get { return HiddenLayersParameter.Value.Value; }
      set {
        if (value < 0 || value > 2) throw new ArgumentException("The number of hidden layers should be set to 0, 1, or 2.", "HiddenLayers");
        HiddenLayersParameter.Value = (from v in HiddenLayersParameter.ValidValues
                                       where v.Value == value
                                       select v)
                                      .Single();
      }
    }
    /// <summary>
    /// Gets or sets the number of nodes in the first hidden layer (ignored when there are no hidden layers).
    /// </summary>
    /// <exception cref="ArgumentException">The value is smaller than one.</exception>
    public int NodesInFirstHiddenLayer {
      get { return NodesInFirstHiddenLayerParameter.Value.Value; }
      set {
        if (value < 1) throw new ArgumentException("The number of nodes in the first hidden layer must be at least one.", "NodesInFirstHiddenLayer");
        NodesInFirstHiddenLayerParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Gets or sets the number of nodes in the second hidden layer (ignored unless there are two hidden layers).
    /// </summary>
    /// <exception cref="ArgumentException">The value is smaller than one.</exception>
    public int NodesInSecondHiddenLayer {
      get { return NodesInSecondHiddenLayerParameter.Value.Value; }
      set {
        if (value < 1) throw new ArgumentException("The number of nodes in the second hidden layer must be at least one.", "NodesInSecondHiddenLayer");
        NodesInSecondHiddenLayerParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Gets or sets the number of training restarts.
    /// </summary>
    /// <exception cref="ArgumentException">The value is smaller than one.</exception>
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set {
        // ALGLIB's mlpetraines rejects Restarts < 1, so zero is not a usable setting either.
        if (value < 1) throw new ArgumentException("The number of restarts must be positive.", "Restarts");
        RestartsParameter.Value.Value = value;
      }
    }
    #endregion


    [StorableConstructor]
    private NeuralNetworkEnsembleClassification(bool deserializing) : base(deserializing) { }
    private NeuralNetworkEnsembleClassification(NeuralNetworkEnsembleClassification original, Cloner cloner)
      : base(original, cloner) {
    }
    /// <summary>
    /// Creates the algorithm with default parameter values and an empty <see cref="ClassificationProblem"/>.
    /// </summary>
    public NeuralNetworkEnsembleClassification()
      : base() {
      var validHiddenLayerValues = new ItemSet<IntValue>(new IntValue[] { new IntValue(0), new IntValue(1), new IntValue(2) });
      // the default selection must be an instance contained in the valid-values set
      var selectedHiddenLayerValue = (from v in validHiddenLayerValues
                                      where v.Value == 1
                                      select v)
                                     .Single();
      Parameters.Add(new FixedValueParameter<IntValue>(EnsembleSizeParameterName, "The number of simple neural network models in the ensemble. A good value is 10.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(DecayParameterName, "The decay parameter for the training phase of the neural network. This parameter determines the strength of regularization and should be set to a value between 0.001 (weak regularization) to 100 (very strong regularization). The correct value should be determined via cross-validation.", new DoubleValue(0.001)));
      Parameters.Add(new ConstrainedValueParameter<IntValue>(HiddenLayersParameterName, "The number of hidden layers for the neural network (0, 1, or 2)", validHiddenLayerValues, selectedHiddenLayerValue));
      Parameters.Add(new FixedValueParameter<IntValue>(NodesInFirstHiddenLayerParameterName, "The number of nodes in the first hidden layer. The value should be rather large (30-100 nodes) in order to make the network highly flexible and run into the early stopping criterion. This value is not used if the number of hidden layers is zero.", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(NodesInSecondHiddenLayerParameterName, "The number of nodes in the second hidden layer. This value is not used if the number of hidden layers is zero or one.", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of restarts for learning.", new IntValue(2)));

      Problem = new ClassificationProblem();
    }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NeuralNetworkEnsembleClassification(this, cloner);
    }

    #region neural network ensemble
    /// <summary>
    /// Trains the ensemble on the current problem and publishes the solution together with
    /// its training-set error measures as results.
    /// </summary>
    protected override void Run() {
      double rmsError, avgRelError, relClassError;
      var solution = CreateNeuralNetworkEnsembleClassificationSolution(Problem.ProblemData, EnsembleSize, HiddenLayers, NodesInFirstHiddenLayer, NodesInSecondHiddenLayer, Decay, Restarts, out rmsError, out avgRelError, out relClassError);
      Results.Add(new Result(NeuralNetworkEnsembleClassificationModelResultName, "The neural network ensemble classification solution.", solution));
      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the neural network ensemble classification solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the neural network ensemble classification solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Relative classification error", "The percentage of misclassified samples on the training set.", new PercentValue(relClassError)));
    }

    /// <summary>
    /// Trains an ALGLIB neural network ensemble classifier on the training partition of
    /// <paramref name="problemData"/> and wraps it in a classification solution.
    /// </summary>
    /// <param name="problemData">Problem data providing dataset, target, inputs, training rows and class values.</param>
    /// <param name="ensembleSize">Number of networks in the ensemble.</param>
    /// <param name="nLayers">Number of hidden layers (0, 1, or 2).</param>
    /// <param name="nHiddenNodes1">Nodes in the first hidden layer (used when <paramref name="nLayers"/> &gt;= 1).</param>
    /// <param name="nHiddenNodes2">Nodes in the second hidden layer (used when <paramref name="nLayers"/> == 2).</param>
    /// <param name="decay">Weight-decay coefficient passed to ALGLIB.</param>
    /// <param name="restarts">Number of training restarts passed to ALGLIB.</param>
    /// <param name="rmsError">Root mean squared error on the training rows.</param>
    /// <param name="avgRelError">Average relative error on the training rows.</param>
    /// <param name="relClassError">Fraction of misclassified training rows.</param>
    /// <returns>The trained neural network ensemble classification solution.</returns>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinite values.</exception>
    /// <exception cref="ArgumentException"><paramref name="nLayers"/> is not 0, 1, or 2, or ALGLIB reports a training failure.</exception>
    public static IClassificationSolution CreateNeuralNetworkEnsembleClassificationSolution(IClassificationProblemData problemData, int ensembleSize, int nLayers, int nHiddenNodes1, int nHiddenNodes2, double decay, int restarts,
      out double rmsError, out double avgRelError, out double relClassError) {
      Dataset dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      // materialize once: the input variable list is used several times below
      string[] allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      IEnumerable<int> rows = problemData.TrainingIndizes;
      // the target variable is appended as the last column of the matrix
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("Neural network ensemble classification does not support NaN or infinity values in the input dataset.");

      int nRows = inputMatrix.GetLength(0);
      int nFeatures = inputMatrix.GetLength(1) - 1;
      // Use exactly the class values (and order) that are handed to the model below, so the class
      // indices used for training agree with the index->class mapping the model applies.
      double[] classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;
      // ALGLIB classifiers expect the target column to contain class indices 0..nClasses-1
      var classIndices = new Dictionary<double, int>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
      }

      int nInputs = allowedInputVariables.Length;
      alglib.mlpensemble mlpEnsemble;
      switch (nLayers) {
        case 0:
          alglib.mlpecreatec0(nInputs, nClasses, ensembleSize, out mlpEnsemble);
          break;
        case 1:
          alglib.mlpecreatec1(nInputs, nHiddenNodes1, nClasses, ensembleSize, out mlpEnsemble);
          break;
        case 2:
          alglib.mlpecreatec2(nInputs, nHiddenNodes1, nHiddenNodes2, nClasses, ensembleSize, out mlpEnsemble);
          break;
        default:
          throw new ArgumentException("Number of layers must be zero, one, or two.", "nLayers");
      }

      int info;
      alglib.mlpreport rep;
      alglib.mlpetraines(mlpEnsemble, inputMatrix, nRows, decay, restarts, out info, out rep);
      // return code 6 is the only value this code accepts as successful training
      if (info != 6) throw new ArgumentException("Error in calculation of neural network ensemble classification solution (ALGLIB return code: " + info + ").");

      rmsError = alglib.mlpermserror(mlpEnsemble, inputMatrix, nRows);
      avgRelError = alglib.mlpeavgrelerror(mlpEnsemble, inputMatrix, nRows);
      relClassError = alglib.mlperelclserror(mlpEnsemble, inputMatrix, nRows);

      return new NeuralNetworkEnsembleClassificationSolution(problemData, new NeuralNetworkEnsembleModel(mlpEnsemble, targetVariable, allowedInputVariables, classValues));
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.