Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NeuralNetwork/NeuralNetworkEnsembleClassification.cs @ 17717

Last change on this file since 17717 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 12.1 KB
RevLine 
[6577]1#region License Information
2/* HeuristicLab
[17181]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[6577]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[15061]25using System.Threading;
[6577]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
[10030]30using HeuristicLab.Parameters;
[17097]31using HEAL.Attic;
[6577]32using HeuristicLab.Problems.DataAnalysis;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Neural network ensemble classification data analysis algorithm.
  /// Trains an ALGLIB multi-layer perceptron ensemble on the training partition of the problem
  /// and publishes the resulting classification solution together with training-set error measures.
  /// </summary>
  [Item("Neural Network Ensemble Classification (NN)", "Neural network ensemble classification data analysis algorithm (wrapper for ALGLIB). Further documentation: http://www.alglib.net/dataanalysis/mlpensembles.php")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 140)]
  [StorableType("21B48D73-B907-4710-854A-C549F8C66CFF")]
  public sealed class NeuralNetworkEnsembleClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    // Parameter names double as lookup keys in the Parameters collection; do not rename.
    private const string EnsembleSizeParameterName = "EnsembleSize";
    private const string DecayParameterName = "Decay";
    private const string HiddenLayersParameterName = "HiddenLayers";
    private const string NodesInFirstHiddenLayerParameterName = "NodesInFirstHiddenLayer";
    private const string NodesInSecondHiddenLayerParameterName = "NodesInSecondHiddenLayer";
    private const string RestartsParameterName = "Restarts";
    private const string NeuralNetworkEnsembleClassificationModelResultName = "Neural network ensemble classification solution";

    #region parameter properties
    /// <summary>Parameter holding the number of networks in the ensemble.</summary>
    public IFixedValueParameter<IntValue> EnsembleSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EnsembleSizeParameterName]; }
    }
    /// <summary>Parameter holding the weight decay used during training.</summary>
    public IFixedValueParameter<DoubleValue> DecayParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[DecayParameterName]; }
    }
    /// <summary>Parameter holding the number of hidden layers, constrained to 0, 1, or 2.</summary>
    public IConstrainedValueParameter<IntValue> HiddenLayersParameter {
      get { return (IConstrainedValueParameter<IntValue>)Parameters[HiddenLayersParameterName]; }
    }
    /// <summary>Parameter holding the number of nodes in the first hidden layer.</summary>
    public IFixedValueParameter<IntValue> NodesInFirstHiddenLayerParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NodesInFirstHiddenLayerParameterName]; }
    }
    /// <summary>Parameter holding the number of nodes in the second hidden layer.</summary>
    public IFixedValueParameter<IntValue> NodesInSecondHiddenLayerParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NodesInSecondHiddenLayerParameterName]; }
    }
    /// <summary>Parameter holding the number of training restarts.</summary>
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Number of networks in the ensemble.
    /// </summary>
    /// <exception cref="ArgumentException">The value is less than one.</exception>
    public int EnsembleSize {
      get { return EnsembleSizeParameter.Value.Value; }
      set {
        if (value < 1) throw new ArgumentException("The number of models in the ensemble must be positive and at least one.", "EnsembleSize");
        EnsembleSizeParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Weight decay (regularization strength) for training.
    /// </summary>
    /// <exception cref="ArgumentException">The value lies outside [0.001, 100].</exception>
    public double Decay {
      get { return DecayParameter.Value.Value; }
      set {
        if (value < 0.001 || value > 100) throw new ArgumentException("The decay parameter should be set to a value between 0.001 and 100.", "Decay");
        DecayParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Number of hidden layers (0, 1, or 2).
    /// </summary>
    /// <remarks>
    /// The setter selects the matching item from the parameter's valid values, because the
    /// constrained parameter only accepts instances contained in that set.
    /// </remarks>
    /// <exception cref="ArgumentException">The value is not 0, 1, or 2.</exception>
    public int HiddenLayers {
      get { return HiddenLayersParameter.Value.Value; }
      set {
        if (value < 0 || value > 2) throw new ArgumentException("The number of hidden layers should be set to 0, 1, or 2.", "HiddenLayers");
        HiddenLayersParameter.Value = HiddenLayersParameter.ValidValues.Single(v => v.Value == value);
      }
    }
    /// <summary>
    /// Number of nodes in the first hidden layer; ignored when <see cref="HiddenLayers"/> is zero.
    /// </summary>
    /// <exception cref="ArgumentException">The value is less than one.</exception>
    public int NodesInFirstHiddenLayer {
      get { return NodesInFirstHiddenLayerParameter.Value.Value; }
      set {
        if (value < 1) throw new ArgumentException("The number of nodes in the first hidden layer must be at least one.", "NodesInFirstHiddenLayer");
        NodesInFirstHiddenLayerParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Number of nodes in the second hidden layer; only used when <see cref="HiddenLayers"/> is two.
    /// </summary>
    /// <exception cref="ArgumentException">The value is less than one.</exception>
    public int NodesInSecondHiddenLayer {
      get { return NodesInSecondHiddenLayerParameter.Value.Value; }
      set {
        if (value < 1) throw new ArgumentException("The number of nodes in the second hidden layer must be at least one.", "NodesInSecondHiddenLayer");
        NodesInSecondHiddenLayerParameter.Value.Value = value;
      }
    }
    /// <summary>
    /// Number of training restarts passed to the ALGLIB ensemble trainer.
    /// </summary>
    /// <exception cref="ArgumentException">The value is less than one.</exception>
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set {
        // Per the ALGLIB documentation, mlpetraines requires Restarts > 0; zero would only fail
        // later during training, so reject it here, consistent with the message below.
        if (value < 1) throw new ArgumentException("The number of restarts must be positive.", "Restarts");
        RestartsParameter.Value.Value = value;
      }
    }
    #endregion


    [StorableConstructor]
    private NeuralNetworkEnsembleClassification(StorableConstructorFlag _) : base(_) { }
    private NeuralNetworkEnsembleClassification(NeuralNetworkEnsembleClassification original, Cloner cloner)
      : base(original, cloner) {
    }
    /// <summary>
    /// Creates the algorithm with default parameters (10 networks, decay 0.001, one hidden layer of
    /// 100 nodes, 2 restarts) and an empty <see cref="ClassificationProblem"/>.
    /// </summary>
    public NeuralNetworkEnsembleClassification()
      : base() {
      var validHiddenLayerValues = new ItemSet<IntValue>(new IntValue[] {
        (IntValue)new IntValue(0).AsReadOnly(),
        (IntValue)new IntValue(1).AsReadOnly(),
        (IntValue)new IntValue(2).AsReadOnly() });
      var selectedHiddenLayerValue = validHiddenLayerValues.Single(v => v.Value == 1);
      Parameters.Add(new FixedValueParameter<IntValue>(EnsembleSizeParameterName, "The number of simple neural network models in the ensemble. A good value is 10.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(DecayParameterName, "The decay parameter for the training phase of the neural network. This parameter determines the strength of regularization and should be set to a value between 0.001 (weak regularization) to 100 (very strong regularization). The correct value should be determined via cross-validation.", new DoubleValue(0.001)));
      Parameters.Add(new ConstrainedValueParameter<IntValue>(HiddenLayersParameterName, "The number of hidden layers for the neural network (0, 1, or 2)", validHiddenLayerValues, selectedHiddenLayerValue));
      Parameters.Add(new FixedValueParameter<IntValue>(NodesInFirstHiddenLayerParameterName, "The number of nodes in the first hidden layer. The value should be rather large (30-100 nodes) in order to make the network highly flexible and run into the early stopping criterion. This value is not used if the number of hidden layers is zero.", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(NodesInSecondHiddenLayerParameterName, "The number of nodes in the second hidden layer. This value is not used if the number of hidden layers is zero or one.", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of restarts for learning.", new IntValue(2)));

      // Only the ensemble size and decay are shown by default; the topology and restart settings are hidden.
      HiddenLayersParameter.Hidden = true;
      NodesInFirstHiddenLayerParameter.Hidden = true;
      NodesInSecondHiddenLayerParameter.Hidden = true;
      RestartsParameter.Hidden = true;

      Problem = new ClassificationProblem();
    }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NeuralNetworkEnsembleClassification(this, cloner);
    }

    #region neural network ensemble
    /// <summary>
    /// Trains the ensemble using the current parameter values and adds the solution and its
    /// training-set error measures to <c>Results</c>.
    /// </summary>
    /// <remarks>
    /// <paramref name="cancellationToken"/> is not observed here; the training call runs to completion once started.
    /// </remarks>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, avgRelError, relClassError;
      var solution = CreateNeuralNetworkEnsembleClassificationSolution(Problem.ProblemData, EnsembleSize, HiddenLayers, NodesInFirstHiddenLayer, NodesInSecondHiddenLayer, Decay, Restarts, out rmsError, out avgRelError, out relClassError);
      Results.Add(new Result(NeuralNetworkEnsembleClassificationModelResultName, "The neural network ensemble classification solution.", solution));
      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the neural network ensemble classification solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the neural network ensemble classification solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Relative classification error", "The percentage of misclassified samples.", new PercentValue(relClassError)));
    }

    /// <summary>
    /// Trains an ALGLIB MLP ensemble classifier on the training rows of <paramref name="problemData"/>
    /// and wraps it in a classification solution.
    /// </summary>
    /// <param name="problemData">Problem data; only its training rows are used for training and for the reported errors.</param>
    /// <param name="ensembleSize">Number of networks in the ensemble.</param>
    /// <param name="nLayers">Number of hidden layers (0, 1, or 2).</param>
    /// <param name="nHiddenNodes1">Nodes in the first hidden layer (used when <paramref name="nLayers"/> is 1 or 2).</param>
    /// <param name="nHiddenNodes2">Nodes in the second hidden layer (used when <paramref name="nLayers"/> is 2).</param>
    /// <param name="decay">Weight decay passed to the trainer.</param>
    /// <param name="restarts">Number of restarts passed to the trainer.</param>
    /// <param name="rmsError">Root mean square error on the training rows.</param>
    /// <param name="avgRelError">Average relative error on the training rows.</param>
    /// <param name="relClassError">Relative classification error (fraction of misclassified training rows).</param>
    /// <returns>A solution wrapping the trained ensemble model and a clone of <paramref name="problemData"/>.</returns>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinity values.</exception>
    /// <exception cref="ArgumentException"><paramref name="nLayers"/> is not 0, 1, or 2, or training did not report success.</exception>
    public static IClassificationSolution CreateNeuralNetworkEnsembleClassificationSolution(IClassificationProblemData problemData, int ensembleSize, int nLayers, int nHiddenNodes1, int nHiddenNodes2, double decay, int restarts,
      out double rmsError, out double avgRelError, out double relClassError) {
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
      IEnumerable<int> rows = problemData.TrainingIndices;
      // Input columns first, target appended as the last column of the training matrix.
      double[,] inputMatrix = dataset.ToArray(allowedInputVariables.Concat(new string[] { targetVariable }), rows);
      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException("Neural network ensemble classification does not support NaN or infinity values in the input dataset.");

      int nRows = inputMatrix.GetLength(0);
      int nFeatures = inputMatrix.GetLength(1) - 1;

      // Map the original class values (distinct, ascending, over all dataset rows) to indices
      // 0..nClasses-1 in the target column, as required by the ALGLIB classifier.
      // NOTE(review): the model below is given problemDataClone.ClassValues; this assumes that
      // sequence has the same ascending order as classValues — confirm in ClassificationProblemData.
      double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
      int nClasses = classValues.Length;
      var classIndices = new Dictionary<double, double>(nClasses);
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
      }

      // nFeatures equals the number of allowed input variables; using it avoids re-enumerating them.
      alglib.mlpensemble mlpEnsemble;
      switch (nLayers) {
        case 0:
          alglib.mlpecreatec0(nFeatures, nClasses, ensembleSize, out mlpEnsemble);
          break;
        case 1:
          alglib.mlpecreatec1(nFeatures, nHiddenNodes1, nClasses, ensembleSize, out mlpEnsemble);
          break;
        case 2:
          alglib.mlpecreatec2(nFeatures, nHiddenNodes1, nHiddenNodes2, nClasses, ensembleSize, out mlpEnsemble);
          break;
        default:
          throw new ArgumentException("Number of layers must be zero, one, or two.", "nLayers");
      }

      int info;
      alglib.mlpreport rep;
      alglib.mlpetraines(mlpEnsemble, inputMatrix, nRows, decay, restarts, out info, out rep);
      // For mlpetraines, ALGLIB reports success with info code 6; anything else is treated as failure.
      if (info != 6) throw new ArgumentException("Error in calculation of neural network ensemble classification solution");

      rmsError = alglib.mlpermserror(mlpEnsemble, inputMatrix, nRows);
      avgRelError = alglib.mlpeavgrelerror(mlpEnsemble, inputMatrix, nRows);
      relClassError = alglib.mlperelclserror(mlpEnsemble, inputMatrix, nRows);
      var problemDataClone = (IClassificationProblemData)problemData.Clone();
      return new NeuralNetworkEnsembleClassificationSolution(new NeuralNetworkEnsembleModel(mlpEnsemble, targetVariable, allowedInputVariables, problemDataClone.ClassValues.ToArray()), problemDataClone);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.