Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs @ 8786

Last change on this file since 8786 was 8786, checked in by mkommend, 11 years ago

#1968: Added seed and m parameter to random forest modeling.

File size: 10.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  /// <summary>
35  /// Random forest classification data analysis algorithm.
36  /// </summary>
37  [Item("Random Forest Classification", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
38  [Creatable("Data Analysis")]
39  [StorableClass]
40  public sealed class RandomForestClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
41    private const string RandomForestClassificationModelResultName = "Random forest classification solution";
42    private const string NumberOfTreesParameterName = "Number of trees";
43    private const string RParameterName = "R";
44    private const string MParameterName = "M";
45    private const string SeedParameterName = "Seed";
46    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
47
48    #region parameter properties
49    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
50      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
51    }
52    public IFixedValueParameter<DoubleValue> RParameter {
53      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
54    }
55    public IFixedValueParameter<DoubleValue> MParameter {
56      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
57    }
58    public IFixedValueParameter<IntValue> SeedParameter {
59      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
60    }
61    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
62      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
63    }
64    #endregion
65    #region properties
66    public int NumberOfTrees {
67      get { return NumberOfTreesParameter.Value.Value; }
68      set { NumberOfTreesParameter.Value.Value = value; }
69    }
70    public double R {
71      get { return RParameter.Value.Value; }
72      set { RParameter.Value.Value = value; }
73    }
74    public double M {
75      get { return MParameter.Value.Value; }
76      set { MParameter.Value.Value = value; }
77    }
78    public int Seed {
79      get { return SeedParameter.Value.Value; }
80      set { SeedParameter.Value.Value = value; }
81    }
82    public bool SetSeedRandomly {
83      get { return SetSeedRandomlyParameter.Value.Value; }
84      set { SetSeedRandomlyParameter.Value.Value = value; }
85    }
86    #endregion
87
88    [StorableConstructor]
89    private RandomForestClassification(bool deserializing) : base(deserializing) { }
90    private RandomForestClassification(RandomForestClassification original, Cloner cloner)
91      : base(original, cloner) {
92    }
93
94    public RandomForestClassification()
95      : base() {
96      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
97      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
98      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
99      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
100      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
101      Problem = new ClassificationProblem();
102    }
103
104    [StorableHook(HookType.AfterDeserialization)]
105    private void AfterDeserialization() {
106      if (!Parameters.ContainsKey(MParameterName))
107        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
108      if (!Parameters.ContainsKey(SeedParameterName))
109        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
110      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
111        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
112    }
113
114    public override IDeepCloneable Clone(Cloner cloner) {
115      return new RandomForestClassification(this, cloner);
116    }
117
118    #region random forest
119    protected override void Run() {
120      double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
121      if (SetSeedRandomly) Seed = new System.Random().Next();
122
123      var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
124      Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
125      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
126      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
127      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
128      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
129    }
130
131    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
132      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
133      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
134      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
135
136      lock (alglib.math.rndobject) {
137        alglib.math.rndobject = new System.Random(seed);
138      }
139
140      Dataset dataset = problemData.Dataset;
141      string targetVariable = problemData.TargetVariable;
142      IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
143      IEnumerable<int> rows = problemData.TrainingIndices;
144      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
145      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
146        throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
147
148      int info = 0;
149      alglib.decisionforest dForest = new alglib.decisionforest();
150      alglib.dfreport rep = new alglib.dfreport(); ;
151      int nRows = inputMatrix.GetLength(0);
152      int nColumns = inputMatrix.GetLength(1);
153      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
154      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
155
156
157      double[] classValues = problemData.ClassValues.ToArray();
158      int nClasses = problemData.Classes;
159      // map original class values to values [0..nClasses-1]
160      Dictionary<double, double> classIndices = new Dictionary<double, double>();
161      for (int i = 0; i < nClasses; i++) {
162        classIndices[classValues[i]] = i;
163      }
164      for (int row = 0; row < nRows; row++) {
165        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
166      }
167      // execute random forest algorithm     
168      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
169      if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
170
171      rmsError = rep.rmserror;
172      outOfBagRmsError = rep.oobrmserror;
173      relClassificationError = rep.relclserror;
174      outOfBagRelClassificationError = rep.oobrelclserror;
175      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
176    }
177    #endregion
178  }
179}
Note: See TracBrowser for help on using the repository browser.