[6240] | 1 | #region License Information
|
---|
| 2 | /* HeuristicLab
|
---|
[17180] | 3 | * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
[6240] | 4 | *
|
---|
| 5 | * This file is part of HeuristicLab.
|
---|
| 6 | *
|
---|
| 7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
| 8 | * it under the terms of the GNU General Public License as published by
|
---|
| 9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
| 10 | * (at your option) any later version.
|
---|
| 11 | *
|
---|
| 12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
| 13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
| 14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
| 15 | * GNU General Public License for more details.
|
---|
| 16 | *
|
---|
| 17 | * You should have received a copy of the GNU General Public License
|
---|
| 18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
| 19 | */
|
---|
| 20 | #endregion
|
---|
| 21 |
|
---|
[17154] | 22 | using System.Collections.Generic;
|
---|
| 23 | using System.Linq;
|
---|
[14523] | 24 | using System.Threading;
|
---|
[17154] | 25 | using HEAL.Attic;
|
---|
| 26 | using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
|
---|
[6240] | 27 | using HeuristicLab.Common;
|
---|
| 28 | using HeuristicLab.Core;
|
---|
| 29 | using HeuristicLab.Data;
|
---|
| 30 | using HeuristicLab.Optimization;
|
---|
[8786] | 31 | using HeuristicLab.Parameters;
|
---|
[6240] | 32 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 33 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest regression data analysis algorithm.
  /// </summary>
  [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableType("721CE0EB-82AF-4E49-9900-48E1C67B5E53")]
  public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RandomForestRegressionModelResultName = "Random forest regression solution";
    private const string NumberOfTreesParameterName = "Number of trees";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string ModelCreationParameterName = "ModelCreation";
    // Name of the boolean parameter that was replaced by the ModelCreation parameter;
    // it is only looked up when deserializing files written by older versions.
    private const string LegacyCreateSolutionParameterName = "CreateSolution";

    // Descriptions shared by the constructor and the backwards-compatibility code in AfterDeserialization,
    // so that parameters added to old deserialized instances match freshly constructed ones exactly.
    private const string MParameterDescription = "The ratio of features that will be used in the construction of individual trees (0<m<=1)";
    private const string SeedParameterDescription = "The random seed used to initialize the new pseudo random number generator.";
    private const string SetSeedRandomlyParameterDescription = "True if the random seed should be set to a random value, otherwise false.";
    private const string ModelCreationParameterDescription = "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)";

    #region parameter properties
    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>Number of trees in the forest.</summary>
    public int NumberOfTrees {
      get { return NumberOfTreesParameter.Value.Value; }
      set { NumberOfTreesParameter.Value.Value = value; }
    }
    /// <summary>Ratio of the training set used to build each individual tree (0 &lt; r &lt;= 1).</summary>
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Ratio of features used to build each individual tree (0 &lt; m &lt;= 1).</summary>
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Random seed passed to the forest construction.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>If true, <see cref="Seed"/> is overwritten with a fresh random seed at the start of each run.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Controls which solution (if any) is added to the results at the end of the run.</summary>
    public ModelCreation ModelCreation {
      get { return ModelCreationParameter.Value.Value; }
      set { ModelCreationParameter.Value.Value = value; }
    }
    #endregion
    [StorableConstructor]
    private RandomForestRegression(StorableConstructorFlag _) : base(_) { }
    private RandomForestRegression(RandomForestRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    public RandomForestRegression()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));
      AddModelCreationParameter(ModelCreation.Model);

      Problem = new RegressionProblem();
    }

    // Adds the hidden ModelCreation parameter initialized with the given value.
    private void AddModelCreationParameter(ModelCreation initialValue) {
      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, ModelCreationParameterDescription, new EnumValue<ModelCreation>(initialValue)));
      Parameters[ModelCreationParameterName].Hidden = true;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      if (!Parameters.ContainsKey(MParameterName))
        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, MParameterDescription, new DoubleValue(0.5)));
      if (!Parameters.ContainsKey(SeedParameterName))
        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, SeedParameterDescription, new IntValue(0)));
      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, SetSeedRandomlyParameterDescription, new BoolValue(true)));

      // parameter type has been changed: the boolean "CreateSolution" parameter was replaced by the ModelCreation enum.
      // Very old versions contain neither parameter; they default to creating the full model.
      ModelCreation migratedModelCreation = ModelCreation.Model;
      if (Parameters.ContainsKey(LegacyCreateSolutionParameterName)) {
        IParameter legacyParam = Parameters[LegacyCreateSolutionParameterName];
        Parameters.Remove(legacyParam);
        // Guard the cast: an unexpected parameter type must not abort deserialization; keep the default then.
        var createSolutionParam = legacyParam as FixedValueParameter<BoolValue>;
        if (createSolutionParam != null && !createSolutionParam.Value.Value)
          migratedModelCreation = ModelCreation.QualityOnly;
      }
      // Only add the parameter if it is missing, so a file containing both parameters does not fail on a duplicate key.
      if (!Parameters.ContainsKey(ModelCreationParameterName))
        AddModelCreationParameter(migratedModelCreation);
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestRegression(this, cloner);
    }

    #region random forest
    /// <summary>
    /// Trains a random forest on the training partition of the problem data, reports training and
    /// out-of-bag error measures, and (depending on <see cref="ModelCreation"/>) adds a solution to the results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
      if (SetSeedRandomly) Seed = Random.RandomSeedGenerator.GetSeed();
      var model = CreateRandomForestRegressionModel(Problem.ProblemData, NumberOfTrees, R, M, Seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));

      IRegressionSolution solution = null;
      if (ModelCreation == ModelCreation.Model) {
        solution = model.CreateRegressionSolution(Problem.ProblemData);
      } else if (ModelCreation == ModelCreation.SurrogateModel) {
        var problemData = Problem.ProblemData;
        // The surrogate keeps the training configuration (seed, trees, r, m) so the model can be recalculated lazily.
        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M);
        solution = surrogateModel.CreateRegressionSolution(problemData);
      }
      // Any other ModelCreation value (e.g. QualityOnly) produces only the error measures above.

      if (solution != null) {
        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
      }
    }


    /// <summary>
    /// Trains a random forest on the training partition of <paramref name="problemData"/> and wraps it in a
    /// <see cref="RandomForestRegressionSolution"/> built on a clone of the problem data.
    /// Kept for compatibility with the old API.
    /// </summary>
    /// <param name="problemData">The regression problem data; its training indices are used for training.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of the training set used per tree (0 &lt; r &lt;= 1).</param>
    /// <param name="m">Ratio of features used per tree (0 &lt; m &lt;= 1).</param>
    /// <param name="seed">Random seed for forest construction.</param>
    /// <param name="rmsError">Root mean squared error on the training set.</param>
    /// <param name="avgRelError">Average relative error on the training set.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean squared error.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error.</param>
    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }

    /// <summary>
    /// Trains a random forest model on the training partition (<c>TrainingIndices</c>) of <paramref name="problemData"/>.
    /// See <see cref="CreateRandomForestRegressionModel(IRegressionProblemData, IEnumerable{int}, int, double, double, int, out double, out double, out double, out double)"/>
    /// for the meaning of the remaining parameters.
    /// </summary>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
      double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
      return CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a random forest model on the given rows of <paramref name="problemData"/> using the allowed input
    /// variables as features and the target variable as response.
    /// </summary>
    /// <param name="problemData">The regression problem data providing dataset, inputs and target.</param>
    /// <param name="trainingIndices">Row indices of the dataset used for training.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Ratio of the training rows used per tree (0 &lt; r &lt;= 1).</param>
    /// <param name="m">Ratio of features used per tree (0 &lt; m &lt;= 1).</param>
    /// <param name="seed">Random seed for forest construction.</param>
    /// <param name="rmsError">Root mean squared error reported by ALGLIB (<c>rep.rmserror</c>).</param>
    /// <param name="avgRelError">Average relative error reported by ALGLIB (<c>rep.avgrelerror</c>).</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean squared error (<c>rep.oobrmserror</c>).</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error (<c>rep.oobavgrelerror</c>).</param>
    /// <returns>The trained full random forest model.</returns>
    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {

      // Input columns first, target variable as the last column of the matrix.
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      alglib.dfreport rep;
      // NOTE(review): the argument 1 is presumably the number of classes; in ALGLIB's convention 1 selects regression.
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables);
    }

    #endregion
  }
}
|
---|