Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs @ 18086

Last change on this file since 18086 was 18086, checked in by mkommend, 2 years ago

#2521: Merged trunk changes into branch.

File size: 11.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22extern alias alglib_3_7;
23using alglib_3_7;
24using System.Collections.Generic;
25using System.Linq;
26using System.Threading;
27using HEAL.Attic;
28using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Problems.DataAnalysis;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest regression data analysis algorithm.
  /// </summary>
  [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableType("721CE0EB-82AF-4E49-9900-48E1C67B5E53")]
  public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RandomForestRegressionModelResultName = "Random forest regression solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string ModelCreationParameterName = "ModelCreation";
    // Name of the boolean parameter that older versions stored instead of ModelCreation.
    private const string LegacyCreateSolutionParameterName = "CreateSolution";

    // Descriptions shared by the constructor and the backwards-compatibility hook.
    // They are kept in one place so the two code paths cannot drift apart.
    private const string MParameterDescription = "The ratio of features that will be used in the construction of individual trees (0<m<=1)";
    private const string SeedParameterDescription = "The random seed used to initialize the new pseudo random number generator.";
    private const string SetSeedRandomlyParameterDescription = "True if the random seed should be set to a random value, otherwise false.";
    private const string ModelCreationParameterDescription = "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)";

    #region parameter properties
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Gets or sets the number of trees in the forest.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the ratio of the training set used to build each individual tree (0&lt;r&lt;=1).</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the ratio of features used to build each individual tree (0&lt;m&lt;=1).</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the random seed passed to the forest construction.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets whether <see cref="Seed"/> is replaced by a fresh random seed at the start of each run.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets which results (full model, surrogate model, or none) are produced at the end of a run.</summary>
    public ModelCreation ModelCreation {
      get { return ModelCreationParameter.Value.Value; }
      set { ModelCreationParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestRegression(StorableConstructorFlag _) : base(_) { }
    private RandomForestRegression(RandomForestRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with default parameter values and an empty <see cref="RegressionProblem"/>.
    /// </summary>
    public RandomForestRegression()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      AddModelCreationParameter(ModelCreation.Model);

      Problem = new RegressionProblem();
    }

    /// <summary>
    /// Adds the hidden ModelCreation parameter, initialized with <paramref name="modelCreation"/>.
    /// </summary>
    private void AddModelCreationParameter(ModelCreation modelCreation) {
      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, ModelCreationParameterDescription, new EnumValue<ModelCreation>(modelCreation)));
      Parameters[ModelCreationParameterName].Hidden = true;
    }

    /// <summary>
    /// Upgrades instances loaded from older files by adding parameters they lack.
    /// Also replaces the legacy boolean "CreateSolution" parameter with ModelCreation.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));

      // The parameter type has been changed. The legacy CreateSolution=true maps to
      // ModelCreation.Model and false maps to ModelCreation.QualityOnly. Very old versions
      // contain neither parameter and fall back to ModelCreation.Model.
      var modelCreation = ModelCreation.Model;
      if (Parameters.ContainsKey(LegacyCreateSolutionParameterName)) {
        var legacyParameter = Parameters[LegacyCreateSolutionParameterName];
        Parameters.Remove(legacyParameter);
        // Checked cast: an unexpected parameter type keeps the default instead of
        // throwing a NullReferenceException while the file is being loaded.
        var createSolutionParam = legacyParameter as FixedValueParameter<BoolValue>;
        if (createSolutionParam != null && !createSolutionParam.Value.Value)
          modelCreation = ModelCreation.QualityOnly;
      }
      // Only add when missing, so a file that contains both the legacy and the new
      // parameter does not fail with a duplicate key.
      if (!Parameters.ContainsKey(ModelCreationParameterName))
        AddModelCreationParameter(modelCreation);
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegression(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Trains a random forest on the problem's training partition and reports training and
    /// out-of-bag errors. Depending on <see cref="ModelCreation"/>, it also adds a full or a
    /// surrogate regression solution to the results.
    /// </summary>
    /// <param name="cancellationToken">
    /// Not observed here: the training call below does not accept a token.
    /// </param>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();

      // Read the problem data once so training and solution creation use the same instance.
      var problemData = Problem.ProblemData;
      var model = CreateRandomForestRegressionModel(problemData, NumberOfTrees, R, M, Seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));

      IRegressionSolution solution = null;
      if (ModelCreation == ModelCreation.Model) {
        solution = model.CreateRegressionSolution(problemData);
      } else if (ModelCreation == ModelCreation.SurrogateModel) {
        // The surrogate keeps the data and settings needed to rebuild the forest on demand.
        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M);
        solution = surrogateModel.CreateRegressionSolution(problemData);
      }
      // Any other ModelCreation value yields only the error results above.

      if (solution != null) {
        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
      }
    }

    /// <summary>
    /// Trains a random forest on the training partition of <paramref name="problemData"/> and wraps
    /// it in a solution that holds a clone of the problem data. Kept for compatibility with the old API.
    /// </summary>
    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }

    /// <summary>
    /// Trains a random forest on <c>problemData.TrainingIndices</c>.
    /// See the overload taking explicit training indices for parameter details.
    /// </summary>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
      double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      return CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a random forest regression model on the given rows of the dataset.
    /// </summary>
    /// <param name="problemData">Supplies the dataset, the allowed input variables and the target variable.</param>
    /// <param name="trainingIndices">Rows of the dataset used for training.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of the training set used per tree.</param>
    /// <param name="m">Ratio of features used per tree.</param>
    /// <param name="seed">Random seed forwarded to the forest construction.</param>
    /// <param name="rmsError">Training RMS error (ALGLIB <c>dfreport.rmserror</c>).</param>
    /// <param name="avgRelError">Training average relative error (<c>dfreport.avgrelerror</c>).</param>
    /// <param name="outOfBagRmsError">Out-of-bag RMS error (<c>dfreport.oobrmserror</c>).</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error (<c>dfreport.oobavgrelerror</c>).</param>
    /// <returns>The trained model over the allowed input variables.</returns>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      // Materialize once so the matrix columns and the model's input variable list are
      // guaranteed to come from the same enumeration of the allowed inputs.
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      // Matrix layout: allowed input columns followed by the target as the last column.
      var variables = allowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      alglib.dfreport rep;
      // NOTE(review): the literal 1 is presumably ALGLIB's "number of classes", where 1 means
      // regression. Confirm against RandomForestUtil.
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModelFull(dForest, nTrees, problemData.TargetVariable, allowedInputVariables);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.