Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/11/12 10:44:57 (12 years ago)
Author:
mkommend
Message:

#1968: Added seed and m parameter to random forest modeling.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r8139 r8786  
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
    28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2928using HeuristicLab.Optimization;
     29using HeuristicLab.Parameters;
    3030using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3131using HeuristicLab.Problems.DataAnalysis;
    32 using HeuristicLab.Problems.DataAnalysis.Symbolic;
    33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    34 using HeuristicLab.Parameters;
    3532
    3633namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4542    private const string NumberOfTreesParameterName = "Number of trees";
    4643    private const string RParameterName = "R";
     44    private const string MParameterName = "M";
     45    private const string SeedParameterName = "Seed";
     46    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     47
    4748    #region parameter properties
    48     public IValueParameter<IntValue> NumberOfTreesParameter {
    49       get { return (IValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
     49    public IFixedValueParameter<IntValue> NumberOfTreesParameter {
     50      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfTreesParameterName]; }
    5051    }
    51     public IValueParameter<DoubleValue> RParameter {
    52       get { return (IValueParameter<DoubleValue>)Parameters[RParameterName]; }
     52    public IFixedValueParameter<DoubleValue> RParameter {
     53      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
     54    }
     55    public IFixedValueParameter<DoubleValue> MParameter {
     56      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
     57    }
     58    public IFixedValueParameter<IntValue> SeedParameter {
     59      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     60    }
     61    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
     62      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    5363    }
    5464    #endregion
     
    6272      set { RParameter.Value.Value = value; }
    6373    }
     74    public double M {
     75      get { return MParameter.Value.Value; }
     76      set { MParameter.Value.Value = value; }
     77    }
     78    public int Seed {
     79      get { return SeedParameter.Value.Value; }
     80      set { SeedParameter.Value.Value = value; }
     81    }
     82    public bool SetSeedRandomly {
     83      get { return SetSeedRandomlyParameter.Value.Value; }
     84      set { SetSeedRandomlyParameter.Value.Value = value; }
     85    }
    6486    #endregion
     87
    6588    [StorableConstructor]
    6689    private RandomForestClassification(bool deserializing) : base(deserializing) { }
     
    6891      : base(original, cloner) {
    6992    }
     93
    7094    public RandomForestClassification()
    7195      : base() {
    7296      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfTreesParameterName, "The number of trees in the forest. Should be between 50 and 100", new IntValue(50)));
    7397      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve good generalization error.", new DoubleValue(0.3)));
     98      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     99      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
     100      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    74101      Problem = new ClassificationProblem();
    75102    }
     103
    76104    [StorableHook(HookType.AfterDeserialization)]
    77     private void AfterDeserialization() { }
     105    private void AfterDeserialization() {
     106      if (!Parameters.ContainsKey(MParameterName))
     107        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     108      if (!Parameters.ContainsKey(SeedParameterName))
     109        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
     110      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
     111        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     112    }
    78113
    79114    public override IDeepCloneable Clone(Cloner cloner) {
     
    84119    protected override void Run() {
    85120      double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;
    86       var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     121      if (SetSeedRandomly) Seed = new System.Random().Next();
     122
     123      var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    87124      Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
    88125      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
     
    92129    }
    93130
    94     public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r,
     131    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    95132      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
     133      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
     134      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
     135
     136      lock (alglib.math.rndobject) {
     137        alglib.math.rndobject = new System.Random(seed);
     138      }
     139
    96140      Dataset dataset = problemData.Dataset;
    97141      string targetVariable = problemData.TargetVariable;
     
    102146        throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
    103147
     148      int info = 0;
     149      alglib.decisionforest dForest = new alglib.decisionforest();
     150      alglib.dfreport rep = new alglib.dfreport(); ;
     151      int nRows = inputMatrix.GetLength(0);
     152      int nColumns = inputMatrix.GetLength(1);
     153      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     154      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    104155
    105       alglib.decisionforest dforest;
    106       alglib.dfreport rep;
    107       int nRows = inputMatrix.GetLength(0);
    108       int nCols = inputMatrix.GetLength(1);
    109       int info;
    110       double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
    111       int nClasses = classValues.Count();
     156
     157      double[] classValues = problemData.ClassValues.ToArray();
     158      int nClasses = problemData.Classes;
    112159      // map original class values to values [0..nClasses-1]
    113160      Dictionary<double, double> classIndices = new Dictionary<double, double>();
     
    116163      }
    117164      for (int row = 0; row < nRows; row++) {
    118         inputMatrix[row, nCols - 1] = classIndices[inputMatrix[row, nCols - 1]];
     165        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    119166      }
    120       // execute random forest algorithm
    121       alglib.dfbuildrandomdecisionforest(inputMatrix, nRows, nCols - 1, nClasses, nTrees, r, out info, out dforest, out rep);
     167      // execute random forest algorithm     
     168      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    122169      if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    123170
     
    126173      relClassificationError = rep.relclserror;
    127174      outOfBagRelClassificationError = rep.oobrelclserror;
    128       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dforest, targetVariable, allowedInputVariables, classValues));
     175      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
    129176    }
    130177    #endregion
Note: See TracChangeset for help on using the changeset viewer.