Timestamp: 10/11/12 10:44:57 (12 years ago)
Author: mkommend
Message: #1968: Added seed and m parameter to random forest modeling.

Location: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files: 2 edited

Legend: lines prefixed with "+" were added, lines prefixed with "-" were removed; unprefixed lines are unchanged context.
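
Before the individual diffs, the sketch below shows roughly how the extended algorithm surface could be configured from calling code once this changeset is applied. It is illustrative only: it assumes a HeuristicLab environment, omits problem and data setup, and uses only the properties visible in the diffs below (R plus the new M, Seed and SetSeedRandomly).

    using HeuristicLab.Algorithms.DataAnalysis;

    class ConfigurationSketch {
      static void Main() {
        // Configure the random forest with a fixed feature ratio and a fixed seed,
        // so repeated runs build the same forest (problem/data setup omitted).
        var rf = new RandomForestRegression();
        rf.R = 0.3;                  // existing: ratio of training rows used per tree (0 < r <= 1)
        rf.M = 0.5;                  // new: ratio of features considered per tree (0 < m <= 1)
        rf.SetSeedRandomly = false;  // new: do not draw a random seed at run time ...
        rf.Seed = 42;                // new: ... use this fixed seed instead
      }
    }
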
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r8139 → r8786

 using HeuristicLab.Core;
 using HeuristicLab.Data;
-using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
 using HeuristicLab.Optimization;
+using HeuristicLab.Parameters;
 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
 using HeuristicLab.Problems.DataAnalysis;
-using HeuristicLab.Problems.DataAnalysis.Symbolic;
-using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
-using HeuristicLab.Parameters;

 namespace HeuristicLab.Algorithms.DataAnalysis {
    …
     private const string NumberOfTreesParameterName = "Number of trees";
     private const string RParameterName = "R";
+    private const string MParameterName = "M";
+    private const string SeedParameterName = "Seed";
+    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
+
     #region parameter properties
-    public IValueParameter<IntValue> NumberOfTreesParameter {
-      get { return (IValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
+    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
+      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
     }
-    public IValueParameter<DoubleValue> RParameter {
-      get { return (IValueParameter<DoubleValue>)Parameters[RParameterName]; }
+    public IFixedValueParameter<DoubleValue> RParameter {
+      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
+    }
+    public IFixedValueParameter<DoubleValue> MParameter {
+      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
+    }
+    public IFixedValueParameter<IntValue> SeedParameter {
+      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
+    }
+    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     }
     #endregion
    …
       set { RParameter.Value.Value = value; }
     }
+    public double M {
+      get { return MParameter.Value.Value; }
+      set { MParameter.Value.Value = value; }
+    }
+    public int Seed {
+      get { return SeedParameter.Value.Value; }
+      set { SeedParameter.Value.Value = value; }
+    }
+    public bool SetSeedRandomly {
+      get { return SetSeedRandomlyParameter.Value.Value; }
+      set { SetSeedRandomlyParameter.Value.Value = value; }
+    }
     #endregion
+
     [StorableConstructor]
     private RandomForestClassification(bool deserializing) : base(deserializing) { }
    …
       : base(original, cloner) {
     }
+
     public RandomForestClassification()
       : base() {
       Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
       Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
+      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
+      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
+      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
       Problem = new ClassificationProblem();
     }
+
     [StorableHook(HookType.AfterDeserialization)]
-    private void AfterDeserialization() { }
+    private void AfterDeserialization() {
+      if (!Parameters.ContainsKey(MParameterName))
+        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
+      if (!Parameters.ContainsKey(SeedParameterName))
+        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
+      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
+    }

     public override IDeepCloneable Clone(Cloner cloner) {
    …
     protected override void Run() {
       double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
-      var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
+      if (SetSeedRandomly) Seed = new System.Random().Next();
+
+      var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
       Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
       Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    …
     }

-    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r,
+    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
       out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
+      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
+      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
+
+      lock (alglib.math.rndobject) {
+        alglib.math.rndobject = new System.Random(seed);
+      }
+
       Dataset dataset = problemData.Dataset;
       string targetVariable = problemData.TargetVariable;
    …
         throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");

+      int info = 0;
+      alglib.decisionforest dForest = new alglib.decisionforest();
+      alglib.dfreport rep = new alglib.dfreport();
+      int nRows = inputMatrix.GetLength(0);
+      int nColumns = inputMatrix.GetLength(1);
+      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
+      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

-      alglib.decisionforest dforest;
-      alglib.dfreport rep;
-      int nRows = inputMatrix.GetLength(0);
-      int nCols = inputMatrix.GetLength(1);
-      int info;
-      double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
-      int nClasses = classValues.Count();
+
+      double[] classValues = problemData.ClassValues.ToArray();
+      int nClasses = problemData.Classes;
       // map original class values to values [0..nClasses-1]
       Dictionary<double, double> classIndices = new Dictionary<double, double>();
    …
       }
       for (int row = 0; row < nRows; row++) {
-        inputMatrix[row, nCols - 1] = classIndices[inputMatrix[row, nCols - 1]];
+        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
       }
       // execute random forest algorithm
-      alglib.dfbuildrandomdecisionforest(inputMatrix, nRows, nCols - 1, nClasses, nTrees, r, out info, out dforest, out rep);
+      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
       if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    …
       relClassificationError = rep.relclserror;
       outOfBagRelClassificationError = rep.oobrelclserror;
-      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dforest, targetVariable, allowedInputVariables, classValues));
+      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
     }
     #endregion
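
The substantive change in both files is how the new r and m ratios reach ALGLIB: instead of the old dfbuildrandomdecisionforest call, the builder now derives a per-tree sample size from r and a per-tree feature count from m (each clamped to at least 1) and passes them to the internal dfbuildinternal routine. A minimal, self-contained sketch of that mapping in plain C#, without HeuristicLab or ALGLIB dependencies (the example dimensions are invented):

    using System;

    class RatioMappingSketch {
      // Mirrors the computation added by the changeset: r and m are ratios in (0, 1];
      // nRows/nColumns describe the input matrix, whose last column holds the target.
      static void Main() {
        int nRows = 200, nColumns = 11;   // invented example: 10 features + 1 target column
        double r = 0.3, m = 0.5;          // the changeset's default parameter values

        if (r <= 0 || r > 1) throw new ArgumentException("r must be in (0, 1].");
        if (m <= 0 || m > 1) throw new ArgumentException("m must be in (0, 1].");

        int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);          // training rows bagged per tree
        int nFeatures  = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); // features considered per tree

        Console.WriteLine($"sampleSize = {sampleSize}, nFeatures = {nFeatures}"); // prints 60 and 5
      }
    }

With the defaults (r = 0.3, m = 0.5) each tree is therefore built from 30 % of the training rows and half of the input features.
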
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r8139 → r8786

 using HeuristicLab.Core;
 using HeuristicLab.Data;
-using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
 using HeuristicLab.Optimization;
+using HeuristicLab.Parameters;
 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
 using HeuristicLab.Problems.DataAnalysis;
-using HeuristicLab.Problems.DataAnalysis.Symbolic;
-using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
-using HeuristicLab.Parameters;

 namespace HeuristicLab.Algorithms.DataAnalysis {
    …
     private const string NumberOfTreesParameterName = "Number of trees";
     private const string RParameterName = "R";
+    private const string MParameterName = "M";
+    private const string SeedParameterName = "Seed";
+    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
+
     #region parameter properties
-    public IValueParameter<IntValue> NumberOfTreesParameter {
-      get { return (IValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
+    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
+      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
     }
-    public IValueParameter<DoubleValue> RParameter {
-      get { return (IValueParameter<DoubleValue>)Parameters[RParameterName]; }
+    public IFixedValueParameter<DoubleValue> RParameter {
+      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
+    }
+    public IFixedValueParameter<DoubleValue> MParameter {
+      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
+    }
+    public IFixedValueParameter<IntValue> SeedParameter {
+      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
+    }
+    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     }
     #endregion
    …
       set { RParameter.Value.Value = value; }
     }
+    public double M {
+      get { return MParameter.Value.Value; }
+      set { MParameter.Value.Value = value; }
+    }
+    public int Seed {
+      get { return SeedParameter.Value.Value; }
+      set { SeedParameter.Value.Value = value; }
+    }
+    public bool SetSeedRandomly {
+      get { return SetSeedRandomlyParameter.Value.Value; }
+      set { SetSeedRandomlyParameter.Value.Value = value; }
+    }
     #endregion
     [StorableConstructor]
    …
       : base(original, cloner) {
     }
+
     public RandomForestRegression()
       : base() {
       Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
       Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
+      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
+      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
+      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
       Problem = new RegressionProblem();
     }
+
     [StorableHook(HookType.AfterDeserialization)]
-    private void AfterDeserialization() { }
+    private void AfterDeserialization() {
+      if (!Parameters.ContainsKey(MParameterName))
+        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
+      if (!Parameters.ContainsKey(SeedParameterName))
+        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
+      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
+    }

     public override IDeepCloneable Clone(Cloner cloner) {
    …
     protected override void Run() {
       double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
-      var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
+      if (SetSeedRandomly) Seed = new System.Random().Next();
+
+      var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
       Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
       Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    …
     }

-    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r,
+    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
       out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
+      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
+      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
+
+      lock (alglib.math.rndobject) {
+        alglib.math.rndobject = new System.Random(seed);
+      }
+
       Dataset dataset = problemData.Dataset;
       string targetVariable = problemData.TargetVariable;
    …
         throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");

+      int info = 0;
+      alglib.decisionforest dForest = new alglib.decisionforest();
+      alglib.dfreport rep = new alglib.dfreport();
+      int nRows = inputMatrix.GetLength(0);
+      int nColumns = inputMatrix.GetLength(1);
+      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
+      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

-      alglib.decisionforest dforest;
-      alglib.dfreport rep;
-      int nRows = inputMatrix.GetLength(0);
-
-      int info;
-      alglib.dfbuildrandomdecisionforest(inputMatrix, nRows, allowedInputVariables.Count(), 1, nTrees, r, out info, out dforest, out rep);
+      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
       if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    …
       outOfBagRmsError = rep.oobrmserror;

-      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dforest, targetVariable, allowedInputVariables));
+      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
     }
     #endregion
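
Both Run() methods also gain the same seeding behaviour: when SetSeedRandomly is true a fresh seed is drawn and stored in the Seed parameter, and the static builder then replaces ALGLIB's shared System.Random (alglib.math.rndobject) inside a lock before building the forest, so a given seed reproduces the same forest. The sketch below isolates that pattern with a hypothetical SharedRng holder standing in for the library-global field; note that it locks on a dedicated sync object, whereas the changeset locks on the RNG instance it is about to replace.

    using System;

    // Hypothetical stand-in for the shared generator the ALGLIB C# port keeps in
    // alglib.math.rndobject; only the seeding pattern itself is illustrated here.
    static class SharedRng {
      public static readonly object Sync = new object();
      public static Random Instance = new Random();
    }

    class SeedingSketch {
      static void Main() {
        bool setSeedRandomly = true;
        int seed = 0;

        // Run(): optionally draw a fresh seed and keep it, so the run stays reproducible.
        if (setSeedRandomly) seed = new Random().Next();

        // Builder: re-seed the shared generator under a lock before forest construction,
        // mirroring the lock (alglib.math.rndobject) { ... } block in the changeset.
        lock (SharedRng.Sync) {
          SharedRng.Instance = new Random(seed);
        }

        Console.WriteLine($"Building the forest with seed {seed}.");
      }
    }
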