Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs @ 8789

Last change on this file since 8789 was 8789, checked in by mkommend, 11 years ago

#1081: Merged trunk changes into timeseries branch.

File size: 9.4 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest regression data analysis algorithm.
  /// Builds an ALGLIB decision forest from the training rows of the configured regression
  /// problem and publishes the resulting solution together with its training and
  /// out-of-bag error measures as results.
  /// </summary>
  [Item("Random Forest Regression", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RandomForestRegressionModelResultName = "Random forest regression solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";

    // Descriptions of the parameters that are created both in the constructor and in
    // AfterDeserialization; kept in one place so the two code paths cannot drift apart.
    private const string MParameterDescription = "The ratio of features that will be used in the construction of individual trees (0<m<=1)";
    private const string SeedParameterDescription = "The random seed used to initialize the new pseudo random number generator.";
    private const string SetSeedRandomlyParameterDescription = "True if the random seed should be set to a random value, otherwise false.";

    #region parameter properties
    /// <summary>Parameter holding the number of trees in the forest.</summary>
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    /// <summary>Parameter holding the ratio of training rows used per tree.</summary>
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    /// <summary>Parameter holding the ratio of features used per tree.</summary>
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    /// <summary>Parameter holding the random seed.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>Parameter that decides whether a fresh seed is drawn on every run.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Number of trees in the forest; must be at least 1 for a run to succeed.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Ratio of the training set used to build each tree; must lie in (0, 1] for a run to succeed.</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Ratio of the input features used to build each tree; must lie in (0, 1] for a run to succeed.</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Seed handed to the random number generator; overwritten on each run if <see cref="SetSeedRandomly"/> is true.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>If true, <see cref="Run"/> replaces <see cref="Seed"/> with a newly drawn value before building the forest.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestRegression(bool deserializing) : base(deserializing) { }
    private RandomForestRegression(RandomForestRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with its default parameters (50 trees, R = 0.3, M = 0.5,
    /// seed 0 with random seeding enabled) and a new, empty regression problem.
    /// </summary>
    public RandomForestRegression()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      Problem = new RegressionProblem();
    }

    /// <summary>
    /// Adds the M, Seed and SetSeedRandomly parameters with their default values when a
    /// deserialized instance does not contain them (presumably instances stored before
    /// these parameters were introduced).
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegression(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Builds the random forest for the current problem data and adds the solution and its
    /// error measures (training and out-of-bag) to the results.
    /// </summary>
    protected override void Run() {
      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
      // Store the drawn seed in the parameter so the run can be repeated with the same value.
      if (SetSeedRandomly) Seed = new System.Random().Next();

      var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
    }

    /// <summary>
    /// Builds a random forest regression solution from the training rows of the given problem data.
    /// </summary>
    /// <param name="problemData">Problem data supplying the dataset, the target variable, the allowed input variables and the training rows.</param>
    /// <param name="nTrees">Number of trees in the forest; must be at least 1.</param>
    /// <param name="r">Ratio of training rows used per tree; must be greater than 0 and at most 1.</param>
    /// <param name="m">Ratio of input variables used per tree; must be greater than 0 and at most 1.</param>
    /// <param name="seed">Seed for the random number generator that is installed in ALGLIB before the forest is built.</param>
    /// <param name="rmsError">Root mean square error reported by ALGLIB for the training set.</param>
    /// <param name="avgRelError">Average relative error reported by ALGLIB for the training set.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean square error reported by ALGLIB.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error reported by ALGLIB.</param>
    /// <returns>A regression solution that wraps the built forest and a clone of <paramref name="problemData"/>.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown if <paramref name="r"/> or <paramref name="m"/> is outside (0, 1], if
    /// <paramref name="nTrees"/> is smaller than 1, or if ALGLIB reports that the forest could not be built.
    /// </exception>
    /// <exception cref="NotSupportedException">Thrown if the training data contains NaN or infinity values.</exception>
    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
      // ALGLIB rejects a forest without trees, but only signals it through the generic info
      // code checked further below; fail early with a message that names the offending parameter.
      if (nTrees < 1) throw new ArgumentException("The number of trees in the random forest regression must be greater than zero.");

      // Install a generator with the requested seed as ALGLIB's shared random object.
      // NOTE(review): the lock is released before the forest is built, so concurrent
      // callers can still interleave on the shared generator - confirm whether
      // reproducibility under parallel runs is required.
      lock (alglib.math.rndobject) {
        alglib.math.rndobject = new System.Random(seed);
      }

      Dataset dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
      IEnumerable<int> rows = problemData.TrainingIndices;
      // The target variable is appended last, so the matrix holds the inputs followed by the target column.
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");

      int info = 0;
      alglib.decisionforest dForest = new alglib.decisionforest();
      alglib.dfreport rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      // Convert the ratios into absolute counts, using at least one row and one feature per tree.
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      // nColumns - 1 is the number of input variables (the last column is the target).
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      // The literal 1 is the number-of-classes argument; per the ALGLIB documentation a value of 1 selects regression.
      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");

      rmsError = rep.rmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;
      outOfBagRmsError = rep.oobrmserror;

      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.