source: branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs @ 17045

Last change on this file since 17045 was 17045, checked in by mkommend, 2 years ago

#2952: Intermediate commit of refactoring RF models that is not yet finished.

File size: 10.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using System.Threading;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Problems.DataAnalysis;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Training is delegated to <c>RandomForestUtil.CreateRandomForestModel</c>; the trained forest is
  /// wrapped in a <see cref="RandomForestModelFull"/> together with the target and input variable names.
  /// </remarks>
  [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableType("721CE0EB-82AF-4E49-9900-48E1C67B5E53")]
  public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RandomForestRegressionModelResultName = "Random forest regression solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string CreateSolutionParameterName = "CreateSolution";

    #region parameter properties
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>Number of trees in the forest (value of the "Number of trees" parameter).</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Ratio of training rows used per tree (value of the "R" parameter).</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Ratio of features used per tree (value of the "M" parameter).</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Random seed passed to the forest construction.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>When true, <see cref="Seed"/> is overwritten with a fresh random seed at the start of each run.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>When true, a regression solution is added to the results at the end of the run.</summary>
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestRegression(StorableConstructorFlag _) : base(_) { }
    private RandomForestRegression(RandomForestRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with default parameter values and an empty <see cref="RegressionProblem"/>.
    /// </summary>
    public RandomForestRegression()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Parameters.Add(CreateMParameter());
      Parameters.Add(CreateSeedParameter());
      Parameters.Add(CreateSetSeedRandomlyParameter());
      Parameters.Add(CreateCreateSolutionParameter());

      Problem = new RegressionProblem();
    }

    #region parameter factories
    // These factories are shared by the constructor and AfterDeserialization, so parameters that are
    // added to older serialized instances lacking them get exactly the same description, default value
    // and visibility as parameters on freshly constructed instances.

    /// <summary>Creates the "M" parameter (ratio of features per tree, default 0.5).</summary>
    private static FixedValueParameter<DoubleValue> CreateMParameter() {
      return new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5));
    }

    /// <summary>Creates the "Seed" parameter (default 0).</summary>
    private static FixedValueParameter<IntValue> CreateSeedParameter() {
      return new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0));
    }

    /// <summary>Creates the "SetSeedRandomly" parameter (default true).</summary>
    private static FixedValueParameter<BoolValue> CreateSetSeedRandomlyParameter() {
      return new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true));
    }

    /// <summary>Creates the hidden "CreateSolution" parameter (default true).</summary>
    private static FixedValueParameter<BoolValue> CreateCreateSolutionParameter() {
      var parameter = new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true));
      parameter.Hidden = true;
      return parameter;
    }
    #endregion

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      // Older serialized instances may lack these parameters; add them with their default values.
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(CreateMParameter());
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(CreateSeedParameter());
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(CreateSetSeedRandomlyParameter());
      if (!Parameters.ContainsKey(CreateSolutionParameterName))
        Parameters.Add(CreateCreateSolutionParameter());
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegression(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Trains a random forest on the training partition of <c>Problem.ProblemData</c> and adds the
    /// training and out-of-bag error statistics to the results. If <see cref="CreateSolution"/> is set,
    /// it also adds a regression solution.
    /// </summary>
    /// <param name="cancellationToken">
    /// Currently not forwarded to the training routine, so a run cannot be cancelled while the forest is being built.
    /// </param>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();
      var model = CreateRandomForestRegressionModel(Problem.ProblemData, NumberOfTrees, R, M, Seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));

      if (CreateSolution) {
        var solution = model.CreateRegressionSolution(Problem.ProblemData);
        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
      }
    }

    /// <summary>
    /// Trains a random forest on the training rows of <paramref name="problemData"/> and wraps it in a
    /// <see cref="RandomForestRegressionSolution"/> built on a clone of <paramref name="problemData"/>.
    /// Kept for compatibility with the old API.
    /// </summary>
    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }

    /// <summary>
    /// Trains a random forest on the rows given by <c>problemData.TrainingIndices</c>.
    /// See the overload taking explicit training indices for details.
    /// </summary>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
      double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      return CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a random forest regression model. The features are the allowed input variables of
    /// <paramref name="problemData"/> and the label is its target variable. Only the dataset rows listed
    /// in <paramref name="trainingIndices"/> are used.
    /// </summary>
    /// <param name="problemData">Provides the dataset, allowed input variables and target variable.</param>
    /// <param name="trainingIndices">Dataset rows used for training.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of training rows used per tree.</param>
    /// <param name="m">Ratio of features used per tree.</param>
    /// <param name="seed">Random seed for the forest construction.</param>
    /// <param name="rmsError">Training RMS error reported by the forest construction.</param>
    /// <param name="avgRelError">Training average relative error reported by the forest construction.</param>
    /// <param name="outOfBagRmsError">Out-of-bag RMS error reported by the forest construction.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error reported by the forest construction.</param>
    /// <returns>The trained forest together with the target and input variable names.</returns>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      // Snapshot the input variables once. The same sequence fixes both the column order of the training
      // matrix and the variable order stored in the model, so enumerating it twice must not be relied on.
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      // The target variable is appended as the last column of the training matrix.
      var variables = inputVariables.Concat(new[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      alglib.dfreport rep;
      // NOTE(review): the literal 1 is presumably the class count (1 = regression in ALGLIB's decision
      // forest API) - confirm against RandomForestUtil.CreateRandomForestModel.
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModelFull(dForest, problemData.TargetVariable, inputVariables);
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.