source: branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs @ 17152

Last change on this file since 17152 was 17152, checked in by gkronber, 2 months ago

#2952: readded backwards compatibility code for very old versions of GBT and RF

File size: 12.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using System.Threading;
25using HEAL.Attic;
26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Problems.DataAnalysis;
33
34namespace HeuristicLab.Algorithms.DataAnalysis {
35  /// <summary>
36  /// Random forest classification data analysis algorithm.
37  /// </summary>
38  [Item("Random Forest Classification (RF)", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
39  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 120)]
40  [StorableType("73070CC7-E85E-4851-9F26-C537AE1CC1C0")]
41  public sealed class RandomForestClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
42    private const string RandomForestClassificationModelResultName = "Random forest classification solution";
43    private const string NumberOfTreesParameterName = "Number of trees";
44    private const string RParameterName = "R";
45    private const string MParameterName = "M";
46    private const string SeedParameterName = "Seed";
47    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
48    private const string ModelCreationParameterName = "ModelCreation";
49
50    #region parameter properties
51    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
52      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
53    }
54    public IFixedValueParameter<DoubleValue> RParameter {
55      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
56    }
57    public IFixedValueParameter<DoubleValue> MParameter {
58      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
59    }
60    public IFixedValueParameter<IntValue> SeedParameter {
61      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
62    }
63    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
64      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
65    }
66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
68    }
69    #endregion
70    #region properties
71    public int NumberOfTrees {
72      get { return NumberOfTreesParameter.Value.Value; }
73      set { NumberOfTreesParameter.Value.Value = value; }
74    }
75    public double R {
76      get { return RParameter.Value.Value; }
77      set { RParameter.Value.Value = value; }
78    }
79    public double M {
80      get { return MParameter.Value.Value; }
81      set { MParameter.Value.Value = value; }
82    }
83    public int Seed {
84      get { return SeedParameter.Value.Value; }
85      set { SeedParameter.Value.Value = value; }
86    }
87    public bool SetSeedRandomly {
88      get { return SetSeedRandomlyParameter.Value.Value; }
89      set { SetSeedRandomlyParameter.Value.Value = value; }
90    }
91    public ModelCreation ModelCreation {
92      get { return ModelCreationParameter.Value.Value; }
93      set { ModelCreationParameter.Value.Value = value; }
94    }
95    #endregion
96
97    [StorableConstructor]
98    private RandomForestClassification(StorableConstructorFlag _) : base(_) { }
99    private RandomForestClassification(RandomForestClassification original, Cloner cloner)
100      : base(original, cloner) {
101    }
102
103    public RandomForestClassification()
104      : base() {
105      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
106      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
107      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
108      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
109      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
110      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
111      Parameters[ModelCreationParameterName].Hidden = true;
112
113      Problem = new ClassificationProblem();
114    }
115
116    [StorableHook(HookType.AfterDeserialization)]
117    private void AfterDeserialization() {
118      // BackwardsCompatibility3.3
119      #region Backwards compatible code, remove with 3.4
120      if (!Parameters.ContainsKey(MParameterName))
121        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
122      if (!Parameters.ContainsKey(SeedParameterName))
123        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
124      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
125        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
126
127      // parameter type has been changed
128      if (Parameters.ContainsKey("CreateSolution")) {
129        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
130        Parameters.Remove(createSolutionParam);
131
132        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
133        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
134        Parameters[ModelCreationParameterName].Hidden = true;
135      } else if (!Parameters.ContainsKey(ModelCreationParameterName)) {
136        // very old version contains neither ModelCreationParameter nor CreateSolutionParameter
137        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
138      }
139      #endregion
140    }
141
142    public override IDeepCloneable Clone(Cloner cloner) {
143      return new RandomForestClassification(this, cloner);
144    }
145
146    #region random forest
147    protected override void Run(CancellationToken cancellationToken) {
148      double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
149      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();
150
151      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
152
153      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
154      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
155      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
156      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
157
158
159      IClassificationSolution solution = null;
160      if (ModelCreation == ModelCreation.Model) {
161        solution = model.CreateClassificationSolution(Problem.ProblemData);
162      } else if (ModelCreation == ModelCreation.SurrogateModel) {
163        var problemData = Problem.ProblemData;
164        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray());
165
166        solution = surrogateModel.CreateClassificationSolution(problemData);
167      }
168
169      if (solution != null) {
170        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
171      }
172    }
173
174    // keep for compatibility with old API
175    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
176      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
177      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed,
178        out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
179      return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
180    }
181
182    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
183 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
184      var model = CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
185      return model;
186    }
187
188    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
189      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
190
191      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
192      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
193
194      var classValues = problemData.ClassValues.ToArray();
195      int nClasses = classValues.Length;
196
197      // map original class values to values [0..nClasses-1]
198      var classIndices = new Dictionary<double, double>();
199      for (int i = 0; i < nClasses; i++) {
200        classIndices[classValues[i]] = i;
201      }
202
203      int nRows = inputMatrix.GetLength(0);
204      int nColumns = inputMatrix.GetLength(1);
205      for (int row = 0; row < nRows; row++) {
206        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
207      }
208
209      alglib.dfreport rep;
210      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
211
212      rmsError = rep.rmserror;
213      outOfBagRmsError = rep.oobrmserror;
214      relClassificationError = rep.relclserror;
215      outOfBagRelClassificationError = rep.oobrelclserror;
216
217      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables, classValues);
218    }
219    #endregion
220  }
221}
Note: See TracBrowser for help on using the repository browser.