source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs @ 16565

Last change on this file since 16565 was 16565, checked in by gkronber, 8 months ago

#2520: merged changes from PersistenceOverhaul branch (r16451:16564) into trunk

File size: 9.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Threading;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Optimization;
27using HeuristicLab.Parameters;
28using HEAL.Attic;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest classification data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Wraps <see cref="RandomForestModel"/>'s classification training routine (ALGLIB-based according to the
  /// item description) as a HeuristicLab algorithm operating on an <see cref="IClassificationProblem"/>.
  /// </remarks>
  [Item("Random Forest Classification (RF)", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 120)]
  [StorableType("73070CC7-E85E-4851-9F26-C537AE1CC1C0")]
  public sealed class RandomForestClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string RandomForestClassificationModelResultName = "Random forest classification solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string CreateSolutionParameterName = "CreateSolution";

    // Parameter descriptions are shared by the constructor and by the backwards-compatibility code in
    // AfterDeserialization, so that parameters added to old deserialized instances match freshly created ones.
    private const string NumberOfTreesParameterDescription = "The number of trees in the forest. Should be between 50 and 100";
    private const string RParameterDescription = "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.";
    private const string MParameterDescription = "The ratio of features that will be used in the construction of individual trees (0<m<=1)";
    private const string SeedParameterDescription = "The random seed used to initialize the new pseudo random number generator.";
    private const string SetSeedRandomlyParameterDescription = "True if the random seed should be set to a random value, otherwise false.";
    private const string CreateSolutionParameterDescription = "Flag that indicates if a solution should be produced at the end of the run";

    #region parameter properties
    /// <summary>Parameter holding the number of trees in the forest.</summary>
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    /// <summary>Parameter holding the ratio of the training set used per tree (0&lt;r&lt;=1).</summary>
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    /// <summary>Parameter holding the ratio of features used per tree (0&lt;m&lt;=1).</summary>
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    /// <summary>Parameter holding the random seed passed to the training routine.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>Parameter that, when true, makes <see cref="Run"/> draw a new seed before training.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    /// <summary>Parameter that controls whether a solution result is added at the end of the run.</summary>
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>Number of trees in the forest.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Ratio of the training set used in the construction of individual trees.</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Ratio of features used in the construction of individual trees.</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Random seed passed to the training routine.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>If true, <see cref="Seed"/> is overwritten with a freshly generated seed at the start of each run.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>If true, a <see cref="RandomForestClassificationSolution"/> is added to the results.</summary>
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RandomForestClassification(StorableConstructorFlag _) : base(_) { }
    private RandomForestClassification(RandomForestClassification original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with default parameter values and an empty <see cref="ClassificationProblem"/>.
    /// </summary>
    public RandomForestClassification()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, NumberOfTreesParameterDescription, new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, RParameterDescription, new DoubleValue(0.3)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, CreateSolutionParameterDescription, new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;

      Problem = new ClassificationProblem();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      // Instances persisted by older versions may lack these parameters; add them with the same
      // defaults and descriptions the constructor uses.
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, CreateSolutionParameterDescription, new BoolValue(true)));
        Parameters[CreateSolutionParameterName].Hidden = true;
      }
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestClassification(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Trains a random forest on the current problem data and adds the training and out-of-bag error
    /// metrics (and, if <see cref="CreateSolution"/> is set, the solution) to the results.
    /// </summary>
    /// <param name="cancellationToken">
    /// Not forwarded: the training call used here does not accept a cancellation token.
    /// </param>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
      // The drawn seed is written back to the Seed parameter, so the value actually used stays visible after the run.
      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();

      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest classification solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest classification solution on the training set.", new PercentValue(relClassificationError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest classification solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest classification solution.", new PercentValue(outOfBagRelClassificationError)));

      if (CreateSolution) {
        // The solution gets its own copy of the problem data so later edits to the problem do not affect it.
        var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone());
        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
      }
    }

    /// <summary>
    /// Trains a random forest classification model and wraps it in a solution backed by a clone of <paramref name="problemData"/>.
    /// Kept for compatibility with the old API; see <see cref="CreateRandomForestClassificationModel"/>.
    /// </summary>
    /// <param name="problemData">The classification problem data to train on.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of the training set used in the construction of individual trees (0&lt;r&lt;=1).</param>
    /// <param name="m">Ratio of features used in the construction of individual trees (0&lt;m&lt;=1).</param>
    /// <param name="seed">Random seed for the training routine.</param>
    /// <param name="rmsError">Root mean square error on the training set.</param>
    /// <param name="relClassificationError">Relative classification error on the training set.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean square error.</param>
    /// <param name="outOfBagRelClassificationError">Out-of-bag relative classification error.</param>
    /// <returns>A new solution wrapping the trained model.</returns>
    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed,
        out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
      return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    }

    /// <summary>
    /// Trains a random forest classification model by delegating to
    /// <see cref="RandomForestModel.CreateClassificationModel"/>.
    /// </summary>
    /// <param name="problemData">The classification problem data to train on.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of the training set used in the construction of individual trees (0&lt;r&lt;=1).</param>
    /// <param name="m">Ratio of features used in the construction of individual trees (0&lt;m&lt;=1).</param>
    /// <param name="seed">Random seed for the training routine.</param>
    /// <param name="rmsError">Root mean square error on the training set.</param>
    /// <param name="relClassificationError">Relative classification error on the training set.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean square error.</param>
    /// <param name="outOfBagRelClassificationError">Out-of-bag relative classification error.</param>
    /// <returns>The trained model.</returns>
    public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
      return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed,
        rmsError: out rmsError, relClassificationError: out relClassificationError, outOfBagRmsError: out outOfBagRmsError, outOfBagRelClassificationError: out outOfBagRelClassificationError);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.