Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2434_crossvalidation/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs @ 15802

Last change on this file since 15802 was 14029, checked in by gkronber, 8 years ago

#2434: merged trunk changes r12934:14026 from trunk to branch

File size: 9.4 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using HeuristicLab.Common;
23using HeuristicLab.Core;
24using HeuristicLab.Data;
25using HeuristicLab.Optimization;
26using HeuristicLab.Parameters;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// The algorithm builds a <see cref="RandomForestModel"/> for the regression problem data,
  /// publishes training and out-of-bag error measures as results and, if
  /// <see cref="CreateSolution"/> is set, additionally publishes a
  /// <see cref="RandomForestRegressionSolution"/>.
  /// </remarks>
  [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableClass]
  public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RandomForestRegressionModelResultName = "Random forest regression solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string CreateSolutionParameterName = "CreateSolution";

    // A single process-wide generator is used for drawing seeds. Creating a new System.Random
    // with the default constructor for every run yields identical seeds on the .NET Framework
    // when several runs start within the same system clock tick (e.g. runs started in parallel).
    // System.Random is not thread-safe, therefore all access is guarded by seedGeneratorLock.
    private static readonly object seedGeneratorLock = new object();
    private static readonly System.Random seedGenerator = new System.Random();

    #region parameter properties
    /// <summary>Parameter holding the number of trees in the forest.</summary>
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    /// <summary>Parameter holding the ratio of the training set used for individual trees.</summary>
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    /// <summary>Parameter holding the ratio of features used for individual trees.</summary>
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    /// <summary>Parameter holding the random seed.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>Parameter that controls whether a new seed is drawn at the start of each run.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    /// <summary>Parameter that controls whether a solution is added to the results.</summary>
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Gets or sets the value of <see cref="NumberOfTreesParameter"/>.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="RParameter"/>.</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="MParameter"/>.</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="SeedParameter"/>.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="SetSeedRandomlyParameter"/>.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="CreateSolutionParameter"/>.</summary>
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestRegression(bool deserializing) : base(deserializing) { }
    private RandomForestRegression(RandomForestRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with its default parameters and a new <see cref="RegressionProblem"/>.
    /// </summary>
    public RandomForestRegression()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      AddMParameter();
      AddSeedParameter();
      AddSetSeedRandomlyParameter();
      AddCreateSolutionParameter();

      Problem = new RegressionProblem();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      // Add every parameter that is missing from the deserialized instance with its default value.
      if (!Parameters.ContainsKey(MParameterName))
        AddMParameter();
      if (!Parameters.ContainsKey(SeedParameterName))
        AddSeedParameter();
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        AddSetSeedRandomlyParameter();
      if (!Parameters.ContainsKey(CreateSolutionParameterName))
        AddCreateSolutionParameter();
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegression(this, cloner);
    }

    #region parameter creation
    // The following helpers are shared by the default constructor and the deserialization hook
    // so that both always create parameters with the same description and default value.

    // Adds the "M" parameter (default 0.5).
    private void AddMParameter() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
    }

    // Adds the "Seed" parameter (default 0).
    private void AddSeedParameter() {
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    }

    // Adds the "SetSeedRandomly" parameter (default true).
    private void AddSetSeedRandomlyParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    }

    // Adds the "CreateSolution" parameter (default true) and marks it as hidden.
    private void AddCreateSolutionParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }
    #endregion

    #region random forest
    // Draws a non-negative random seed from the shared generator.
    private static int NextRandomSeed() {
      lock (seedGeneratorLock) {
        return seedGenerator.Next();
      }
    }

    /// <summary>
    /// Builds the random forest model for the current problem data and adds the error measures
    /// (and optionally the solution) to the results.
    /// </summary>
    protected override void Run() {
      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
      // The drawn seed is written back to the Seed parameter before the model is built.
      if (SetSeedRandomly) Seed = NextRandomSeed();
      var model = CreateRandomForestRegressionModel(Problem.ProblemData, NumberOfTrees, R, M, Seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));

      if (CreateSolution) {
        // The solution receives its own clone of the problem data.
        var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone());
        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
      }
    }

    /// <summary>
    /// Creates a random forest regression solution for the given problem data.
    /// Kept for compatibility with the old API.
    /// </summary>
    /// <param name="problemData">The regression problem data; a clone of it is stored in the returned solution.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The ratio of the training set used for individual trees.</param>
    /// <param name="m">The ratio of features used for individual trees.</param>
    /// <param name="seed">The random seed.</param>
    /// <param name="rmsError">Receives the root mean square error on the training set.</param>
    /// <param name="avgRelError">Receives the average relative error on the training set.</param>
    /// <param name="outOfBagRmsError">Receives the out-of-bag root mean square error.</param>
    /// <param name="outOfBagAvgRelError">Receives the out-of-bag average relative error.</param>
    /// <returns>A solution wrapping the created model and a clone of <paramref name="problemData"/>.</returns>
    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }

    /// <summary>
    /// Creates a random forest regression model by delegating to
    /// <see cref="RandomForestModel.CreateRegressionModel"/>.
    /// </summary>
    /// <param name="problemData">The regression problem data.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The ratio of the training set used for individual trees.</param>
    /// <param name="m">The ratio of features used for individual trees.</param>
    /// <param name="seed">The random seed.</param>
    /// <param name="rmsError">Receives the root mean square error on the training set.</param>
    /// <param name="avgRelError">Receives the average relative error on the training set.</param>
    /// <param name="outOfBagRmsError">Receives the out-of-bag root mean square error.</param>
    /// <param name="outOfBagAvgRelError">Receives the out-of-bag average relative error.</param>
    /// <returns>The created random forest model.</returns>
    public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
      double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.