Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/11/12 10:44:57 (12 years ago)
Author:
mkommend
Message:

#1968: Added seed and m parameter to random forest modeling.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r8139 r8786  
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
    28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2928using HeuristicLab.Optimization;
     29using HeuristicLab.Parameters;
    3030using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3131using HeuristicLab.Problems.DataAnalysis;
    32 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    34 using HeuristicLab.Parameters;
    3532
    3633namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4542    private const string NumberOfTreesParameterName = "Number of trees";
    4643    private const string RParameterName = "R";
     44    private const string MParameterName = "M";
     45    private const string SeedParameterName = "Seed";
     46    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     47
    4748    #region parameter properties
    48     public IValueParameter<IntValue> NumberOfTreesParameter {
    49       get { return (IValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
     49    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
     50      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    5051    }
    51     public IValueParameter<DoubleValue> RParameter {
    52       get { return (IValueParameter<DoubleValue>)Parameters[RParameterName]; }
     52    public IFixedValueParameter<DoubleValue> RParameter {
     53      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
     54    }
     55    public IFixedValueParameter<DoubleValue> MParameter {
     56      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
     57    }
     58    public IFixedValueParameter<IntValue> SeedParameter {
     59      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     60    }
     61    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
     62      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    5363    }
    5464    #endregion
     
    6272      set { RParameter.Value.Value = value; }
    6373    }
     74    public double M {
     75      get { return MParameter.Value.Value; }
     76      set { MParameter.Value.Value = value; }
     77    }
     78    public int Seed {
     79      get { return SeedParameter.Value.Value; }
     80      set { SeedParameter.Value.Value = value; }
     81    }
     82    public bool SetSeedRandomly {
     83      get { return SetSeedRandomlyParameter.Value.Value; }
     84      set { SetSeedRandomlyParameter.Value.Value = value; }
     85    }
    6486    #endregion
    6587    [StorableConstructor]
     
    6890      : base(original, cloner) {
    6991    }
     92
    7093    public RandomForestRegression()
    7194      : base() {
    7295      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
    7396      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
     97      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     98      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
     99      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    74100      Problem = new RegressionProblem();
    75101    }
     102
    76103    [StorableHook(HookType.AfterDeserialization)]
    77     private void AfterDeserialization() { }
     104    private void AfterDeserialization() {
     105      if (!Parameters.ContainsKey(MParameterName))
     106        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     107      if (!Parameters.ContainsKey(SeedParameterName))
     108        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
     109      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
     110        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     111    }
    78112
    79113    public override IDeepCloneable Clone(Cloner cloner) {
     
    84118    protected override void Run() {
    85119      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
    86       var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     120      if (SetSeedRandomly) Seed = new System.Random().Next();
     121
     122      var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    87123      Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    88124      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
     
    92128    }
    93129
    94     public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r,
     130    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    95131      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     132      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
     133      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
     134
     135      lock (alglib.math.rndobject) {
     136        alglib.math.rndobject = new System.Random(seed);
     137      }
     138
    96139      Dataset dataset = problemData.Dataset;
    97140      string targetVariable = problemData.TargetVariable;
     
    102145        throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");
    103146
     147      int info = 0;
     148      alglib.decisionforest dForest = new alglib.decisionforest();
     149      alglib.dfreport rep = new alglib.dfreport(); ;
     150      int nRows = inputMatrix.GetLength(0);
     151      int nColumns = inputMatrix.GetLength(1);
     152      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     153      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    104154
    105       alglib.decisionforest dforest;
    106       alglib.dfreport rep;
    107       int nRows = inputMatrix.GetLength(0);
    108 
    109       int info;
    110       alglib.dfbuildrandomdecisionforest(inputMatrix, nRows, allowedInputVariables.Count(), 1, nTrees, r, out info, out dforest, out rep);
     155      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    111156      if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    112157
     
    116161      outOfBagRmsError = rep.oobrmserror;
    117162
    118       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dforest, targetVariable, allowedInputVariables));
     163      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
    119164    }
    120165    #endregion
Note: See TracChangeset for help on using the changeset viewer.