Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs @ 17928

Last change on this file since 17928 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 12.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using System.Threading;
25using HEAL.Attic;
26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Problems.DataAnalysis;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest classification data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Wraps the ALGLIB decision forest: the forest is trained on the training rows of the problem data,
  /// training and out-of-bag error measures are added to the results and, depending on
  /// the ModelCreation setting, a full or surrogate classification solution is added as well.
  /// </remarks>
  [Item("Random Forest Classification (RF)", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 120)]
  [StorableType("73070CC7-E85E-4851-9F26-C537AE1CC1C0")]
  public sealed class RandomForestClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string RandomForestClassificationModelResultName = "Random forest classification solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string ModelCreationParameterName = "ModelCreation";

    // Descriptions shared by the constructor and the backwards-compatibility code in AfterDeserialization,
    // so that freshly created and upgraded instances carry identical parameter descriptions.
    private const string MParameterDescription = "The ratio of features that will be used in the construction of individual trees (0<m<=1)";
    private const string SeedParameterDescription = "The random seed used to initialize the new pseudo random number generator.";
    private const string SetSeedRandomlyParameterDescription = "True if the random seed should be set to a random value, otherwise false.";
    private const string ModelCreationParameterDescription = "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)";

    #region parameter properties
    /// <summary>Parameter holding the number of trees in the forest.</summary>
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    /// <summary>Parameter holding the ratio of training rows used per tree (0 &lt; r &lt;= 1).</summary>
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    /// <summary>Parameter holding the ratio of features used per tree (0 &lt; m &lt;= 1).</summary>
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    /// <summary>Parameter holding the random seed used to build the forest.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>Parameter that decides whether a new random seed is drawn at the start of each run.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>Gets or sets the number of trees in the forest.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the ratio of the training set used to construct each tree (0 &lt; r &lt;= 1).</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the ratio of features used to construct each tree (0 &lt; m &lt;= 1).</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>
    /// Gets or sets the random seed passed to the forest construction.
    /// Overwritten at the start of a run when <see cref="SetSeedRandomly"/> is true.
    /// </summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets whether <see cref="Seed"/> is replaced by a fresh random seed at the start of each run.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets which solution, if any, is added to the results at the end of a run.</summary>
    public ModelCreation ModelCreation {
      get { return ModelCreationParameter.Value.Value; }
      set { ModelCreationParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestClassification(StorableConstructorFlag _) : base(_) { }
    private RandomForestClassification(RandomForestClassification original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with default parameters and an empty <see cref="ClassificationProblem"/>.
    /// </summary>
    public RandomForestClassification()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      AddModelCreationParameter(ModelCreation.Model);

      Problem = new ClassificationProblem();
    }

    // Registers the (hidden) ModelCreation parameter with the given initial value.
    private void AddModelCreationParameter(ModelCreation value) {
      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, ModelCreationParameterDescription, new EnumValue<ModelCreation>(value)));
      Parameters[ModelCreationParameterName].Hidden = true;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));

      // parameter type has been changed: the boolean "CreateSolution" parameter was replaced by ModelCreation
      if (Parameters.ContainsKey("CreateSolution")) {
        var createSolutionParam = Parameters["CreateSolution"];
        Parameters.Remove(createSolutionParam);

        // An unexpected parameter type used to end in a NullReferenceException;
        // fall back to the default (full model) instead.
        var createSolutionBool = createSolutionParam as FixedValueParameter<BoolValue>;
        bool createSolution = createSolutionBool == null || createSolutionBool.Value.Value;
        AddModelCreationParameter(createSolution ? ModelCreation.Model : ModelCreation.QualityOnly);
      } else if (!Parameters.ContainsKey(ModelCreationParameterName)) {
        // very old version contains neither ModelCreationParameter nor CreateSolutionParameter
        AddModelCreationParameter(ModelCreation.Model);
      }
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestClassification(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Builds the forest on the training partition of the problem data, adds the training and
    /// out-of-bag error measures to the results and, when ModelCreation is Model or SurrogateModel,
    /// adds the corresponding classification solution.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();

      var problemData = Problem.ProblemData;
      var model = CreateRandomForestClassificationModel(problemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);

      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest classification solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest classification solution on the training set.", new PercentValue(relClassificationError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest classification solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest classification solution.", new PercentValue(outOfBagRelClassificationError)));

      IClassificationSolution solution = null;
      if (ModelCreation == ModelCreation.Model) {
        solution = model.CreateClassificationSolution(problemData);
      } else if (ModelCreation == ModelCreation.SurrogateModel) {
        // the surrogate stores only the construction parameters so the model can be recalculated lazily
        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray());
        solution = surrogateModel.CreateClassificationSolution(problemData);
      }

      if (solution != null) {
        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
      }
    }

    /// <summary>
    /// Builds a random forest on the training partition of <paramref name="problemData"/> and wraps it in a
    /// <see cref="RandomForestClassificationSolution"/> bound to a clone of the problem data.
    /// Kept for compatibility with the old API.
    /// </summary>
    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed,
        out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
      return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    }

    /// <summary>
    /// Builds a random forest on the training indices of <paramref name="problemData"/>.
    /// </summary>
    /// <remarks>
    /// <paramref name="avgRelError"/> and <paramref name="outOfBagAvgRelError"/> receive the relative
    /// classification errors; their names differ from the other overloads but are kept so that callers
    /// using named arguments keep compiling.
    /// </remarks>
    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      return CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    }

    /// <summary>
    /// Builds a random forest classifier from the given rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Source of the dataset, input variables, target variable and class values.</param>
    /// <param name="trainingIndices">Rows of the dataset used to build the forest.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of the rows used to construct each tree.</param>
    /// <param name="m">Ratio of the features used to construct each tree.</param>
    /// <param name="seed">Random seed passed to the forest construction.</param>
    /// <param name="rmsError">Root mean square error reported by ALGLIB for the training rows.</param>
    /// <param name="relClassificationError">Relative classification error reported by ALGLIB for the training rows.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean square error reported by ALGLIB.</param>
    /// <param name="outOfBagRelClassificationError">Out-of-bag relative classification error reported by ALGLIB.</param>
    /// <returns>The trained forest together with the target variable, input variables and class values.</returns>
    /// <exception cref="KeyNotFoundException">A target value in the selected rows is not one of the problem's class values.</exception>
    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {

      // the target variable is appended as the last column, which is where ALGLIB expects the class label
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      int nRows = inputMatrix.GetLength(0);
      int targetColumn = inputMatrix.GetLength(1) - 1;
      for (int row = 0; row < nRows; row++) {
        double classValue = inputMatrix[row, targetColumn];
        double classIndex;
        if (!classIndices.TryGetValue(classValue, out classIndex)) {
          // same exception type as the former dictionary indexer, but with a diagnosable message
          throw new KeyNotFoundException(string.Format(
            "The value {0} of target variable \"{1}\" is not one of the class values of the problem data.",
            classValue, problemData.TargetVariable));
        }
        inputMatrix[row, targetColumn] = classIndex;
      }

      alglib.dfreport rep;
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables, classValues);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.