Changeset 8786 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
- Timestamp:
- 10/11/12 10:44:57 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r8139 r8786 26 26 using HeuristicLab.Core; 27 27 using HeuristicLab.Data; 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;29 28 using HeuristicLab.Optimization; 29 using HeuristicLab.Parameters; 30 30 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 31 31 using HeuristicLab.Problems.DataAnalysis; 32 using HeuristicLab.Problems.DataAnalysis.Symbolic;33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;34 using HeuristicLab.Parameters;35 32 36 33 namespace HeuristicLab.Algorithms.DataAnalysis { … … 45 42 private const string NumberOfTreesParameterName = "Number of trees"; 46 43 private const string RParameterName = "R"; 44 private const string MParameterName = "M"; 45 private const string SeedParameterName = "Seed"; 46 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 47 47 48 #region parameter properties 48 public I ValueParameter<IntValue> NumberOfTreesParameter {49 get { return (I ValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }49 public IFixedValueParameter<IntValue> NumberOfTreesParameter { 50 get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; } 50 51 } 51 public IValueParameter<DoubleValue> RParameter { 52 get { return (IValueParameter<DoubleValue>)Parameters[RParameterName]; } 52 public IFixedValueParameter<DoubleValue> RParameter { 53 get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; } 54 } 55 public IFixedValueParameter<DoubleValue> MParameter { 56 get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; } 57 } 58 public IFixedValueParameter<IntValue> SeedParameter { 59 get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; } 60 } 61 public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter { 62 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 53 63 } 54 64 #endregion … … 62 72 set { RParameter.Value.Value = value; } 63 73 } 74 public double M { 75 get { return MParameter.Value.Value; } 76 set { MParameter.Value.Value = value; } 77 } 78 public int Seed { 79 get { return SeedParameter.Value.Value; } 80 set { SeedParameter.Value.Value = value; } 81 } 82 public bool SetSeedRandomly { 83 get { return SetSeedRandomlyParameter.Value.Value; } 84 set { SetSeedRandomlyParameter.Value.Value = value; } 85 } 64 86 #endregion 65 87 [StorableConstructor] … … 68 90 : base(original, cloner) { 69 91 } 92 70 93 public RandomForestRegression() 71 94 : base() { 72 95 Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50))); 73 96 Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3))); 97 Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5))); 98 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 99 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 74 100 Problem = new RegressionProblem(); 75 101 } 102 76 103 [StorableHook(HookType.AfterDeserialization)] 77 private void AfterDeserialization() { } 104 private void AfterDeserialization() { 105 if (!Parameters.ContainsKey(MParameterName)) 106 Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5))); 107 if (!Parameters.ContainsKey(SeedParameterName)) 108 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 109 if (!Parameters.ContainsKey((SetSeedRandomlyParameterName))) 110 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 111 } 78 112 79 113 public override IDeepCloneable Clone(Cloner cloner) { … … 84 118 protected override void Run() { 85 119 double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError; 86 var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 120 if (SetSeedRandomly) Seed = new System.Random().Next(); 121 122 var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 87 123 Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution)); 88 124 Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError))); … … 92 128 } 93 129 94 public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, 130 public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 95 131 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 132 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1."); 133 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1."); 134 135 lock (alglib.math.rndobject) { 136 alglib.math.rndobject = new System.Random(seed); 137 } 138 96 139 Dataset dataset = problemData.Dataset; 97 140 string targetVariable = problemData.TargetVariable; … … 102 145 throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset."); 103 146 147 int info = 0; 148 alglib.decisionforest dForest = new alglib.decisionforest(); 149 alglib.dfreport rep = new alglib.dfreport(); ; 150 int nRows = inputMatrix.GetLength(0); 151 int nColumns = inputMatrix.GetLength(1); 152 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 153 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 104 154 105 alglib.decisionforest dforest; 106 alglib.dfreport rep; 107 int nRows = inputMatrix.GetLength(0); 108 109 int info; 110 alglib.dfbuildrandomdecisionforest(inputMatrix, nRows, allowedInputVariables.Count(), 1, nTrees, r, out info, out dforest, out rep); 155 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 111 156 if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution"); 112 157 … … 116 161 outOfBagRmsError = rep.oobrmserror; 117 162 118 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(d forest, targetVariable, allowedInputVariables));163 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables)); 119 164 } 120 165 #endregion
Note: See TracChangeset
for help on using the changeset viewer.