Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs @ 17050

Last change on this file since 17050 was 17050, checked in by mkommend, 5 years ago

#2952: Finished implementation of different RF models.

File size: 11.3 KB
RevLine 
[6240]1#region License Information
2/* HeuristicLab
[16565]3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[6240]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[17045]22using System.Collections.Generic;
23using System.Linq;
[14523]24using System.Threading;
[17045]25using HEAL.Attic;
[17050]26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
[6240]27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
[8786]31using HeuristicLab.Parameters;
[6240]32using HeuristicLab.Problems.DataAnalysis;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest regression data analysis algorithm.
  /// Trains an ALGLIB random forest on the problem data, reports training and out-of-bag
  /// error measures and, depending on the ModelCreation setting, adds a full or a
  /// surrogate regression solution to the results.
  /// </summary>
  [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableType("721CE0EB-82AF-4E49-9900-48E1C67B5E53")]
  public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RandomForestRegressionModelResultName = "Random forest regression solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string ModelCreationParameterName = "ModelCreation";
    // Name of the boolean parameter that was replaced by ModelCreation; only consulted when loading stored objects.
    private const string LegacyCreateSolutionParameterName = "CreateSolution";
    // Shared by the constructor and the deserialization hook so both register the identical parameter.
    private const string ModelCreationParameterDescription = "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)";

    #region parameter properties
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Number of trees in the forest.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Ratio of the training set used to construct each tree (0 &lt; r &lt;= 1).</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Ratio of the features used to construct each tree (0 &lt; m &lt;= 1).</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Random seed; overwritten at the start of <c>Run</c> when <see cref="SetSeedRandomly"/> is true.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Controls which solution (if any) is added to the results at the end of the run.</summary>
    public ModelCreation ModelCreation {
      get { return ModelCreationParameter.Value.Value; }
      set { ModelCreationParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestRegression(StorableConstructorFlag _) : base(_) { }
    private RandomForestRegression(RandomForestRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    public RandomForestRegression()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      AddModelCreationParameter(ModelCreation.Model);

      Problem = new RegressionProblem();
    }

    /// <summary>
    /// Registers the hidden ModelCreation parameter with the given initial value.
    /// </summary>
    /// <param name="value">The initial model-creation mode.</param>
    private void AddModelCreationParameter(ModelCreation value) {
      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, ModelCreationParameterDescription, new EnumValue<ModelCreation>(value)));
      Parameters[ModelCreationParameterName].Hidden = true;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));

      // parameter type has been changed: the boolean CreateSolution flag became the ModelCreation enum
      if (Parameters.ContainsKey(LegacyCreateSolutionParameterName)) {
        var legacyParam = Parameters[LegacyCreateSolutionParameterName];
        Parameters.Remove(legacyParam);

        // Only a boolean flag can be mapped; an unexpected parameter type falls back to the
        // constructor default instead of throwing a NullReferenceException.
        var createSolutionFlag = legacyParam as IFixedValueParameter<BoolValue>;
        bool createSolution = createSolutionFlag == null || createSolutionFlag.Value.Value;
        if (!Parameters.ContainsKey(ModelCreationParameterName))
          AddModelCreationParameter(createSolution ? ModelCreation.Model : ModelCreation.QualityOnly);
      }

      // Objects stored without either parameter would otherwise fail on the ModelCreationParameter
      // lookup in Run; use the same default as the constructor.
      if (!Parameters.ContainsKey(ModelCreationParameterName))
        AddModelCreationParameter(ModelCreation.Model);
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegression(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Trains the forest on the training partition, adds the training and out-of-bag error
    /// measures to the results and, depending on <see cref="ModelCreation"/>, a regression solution.
    /// </summary>
    /// <param name="cancellationToken">Not observed by this implementation.</param>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();

      var problemData = Problem.ProblemData;
      var model = CreateRandomForestRegressionModel(problemData, NumberOfTrees, R, M, Seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));

      IRegressionSolution solution = null;
      switch (ModelCreation) {
        case ModelCreation.Model: {
            solution = model.CreateRegressionSolution(problemData);
            break;
          }
        case ModelCreation.SurrogateModel: {
            // The surrogate stores the training configuration (seed, trees, r, m) alongside the model.
            var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M);
            solution = surrogateModel.CreateRegressionSolution(problemData);
            break;
          }
        // Any other mode (e.g. QualityOnly) reports only the error measures above.
      }

      if (solution != null) {
        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
      }
    }

    /// <summary>
    /// Trains a random forest on the training partition and wraps it in a solution that owns a
    /// clone of <paramref name="problemData"/>. Kept for compatibility with the old API.
    /// </summary>
    /// <param name="problemData">The regression problem data; its training indices are used.</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Ratio of the training set used per tree.</param>
    /// <param name="m">Ratio of the features used per tree.</param>
    /// <param name="seed">Random seed.</param>
    /// <param name="rmsError">Training root mean square error.</param>
    /// <param name="avgRelError">Training average relative error.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean square error.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error.</param>
    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }

    /// <summary>
    /// Trains a random forest on <c>problemData.TrainingIndices</c>.
    /// </summary>
    /// <returns>The trained model; the error measures are returned via the out parameters.</returns>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
     double r, double m, int seed,
     out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      return CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a random forest on the given rows of the problem data's dataset.
    /// </summary>
    /// <param name="problemData">Supplies the dataset, the allowed input variables and the target variable.</param>
    /// <param name="trainingIndices">Row indices used for training; enumerated once.</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Ratio of the training rows used per tree.</param>
    /// <param name="m">Ratio of the features used per tree.</param>
    /// <param name="seed">Random seed.</param>
    /// <param name="rmsError">Training root mean square error reported by ALGLIB.</param>
    /// <param name="avgRelError">Training average relative error reported by ALGLIB.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean square error reported by ALGLIB.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error reported by ALGLIB.</param>
    /// <returns>The trained model over the allowed input variables.</returns>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
    out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      // Columns: the allowed input variables followed by the target variable (in this order).
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      alglib.dfreport rep;
      // NOTE(review): the literal 1 is presumably the number of classes (1 = regression in ALGLIB) — confirm against RandomForestUtil.
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables);
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.