Changeset 11443


Ignore:
Timestamp:
10/10/14 13:58:19 (8 years ago)
Author:
bburlacu
Message:

#2237: Made random forest parameters serializable (by deriving from ParameterCollection).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11426 r11443  
    3030using HeuristicLab.Core;
    3131using HeuristicLab.Data;
     32using HeuristicLab.Parameters;
     33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3234using HeuristicLab.Problems.DataAnalysis;
    3335using HeuristicLab.Random;
    3436
    3537namespace HeuristicLab.Algorithms.DataAnalysis {
    36   public class RFParameter : ICloneable {
    37     public double n; // number of trees
    38     public double m;
    39     public double r;
    40 
    41     public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; }
     38  [Item("RFParameter", "A random forest parameter collection")]
     39  [StorableClass]
     40  public class RFParameter : ParameterCollection {
     41    public RFParameter() {
     42      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
     43      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
     44      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
     45    }
     46
     47    [StorableConstructor]
     48    private RFParameter(bool deserializing)
     49      : base(deserializing) {
     50    }
     51
     52    private RFParameter(RFParameter original, Cloner cloner)
     53      : base(original, cloner) {
     54      this.N = original.N;
     55      this.R = original.R;
     56      this.M = original.M;
     57    }
     58
     59    public override IDeepCloneable Clone(Cloner cloner) {
     60      return new RFParameter(this, cloner);
     61    }
     62
     63    private IFixedValueParameter<IntValue> NParameter {
     64      get { return (IFixedValueParameter<IntValue>)base["N"]; }
     65    }
     66
     67    private IFixedValueParameter<DoubleValue> RParameter {
     68      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
     69    }
     70
     71    private IFixedValueParameter<DoubleValue> MParameter {
     72      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
     73    }
     74
     75    public int N {
     76      get { return NParameter.Value.Value; }
     77      set { NParameter.Value.Value = value; }
     78    }
     79
     80    public double R {
     81      get { return RParameter.Value.Value; }
     82      set { RParameter.Value.Value = value; }
     83    }
     84
     85    public double M {
     86      get { return MParameter.Value.Value; }
     87      set { MParameter.Value.Value = value; }
     88    }
    4289  }
    4390
     
    64111      avgTestMse /= partitions.Length;
    65112    }
     113
    66114    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    67115      avgTestAccuracy = 0;
     
    96144        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
    97145        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    98         RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
     146        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    99147
    100148        lock (locker) {
     
    120168        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
    121169        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    122         RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
     170        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
    123171                                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    124172
     
    135183    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    136184      DoubleValue mse = new DoubleValue(Double.MaxValue);
    137       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
     185      RFParameter bestParameter = new RFParameter();
    138186
    139187      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    148196          setters[i](parameters, parameterValues[i]);
    149197        }
    150         CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE);
     198        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
    151199
    152200        lock (locker) {
     
    162210    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    163211      DoubleValue accuracy = new DoubleValue(0);
    164       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
     212      RFParameter bestParameter = new RFParameter();
    165213
    166214      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    175223          setters[i](parameters, parameterValues[i]);
    176224        }
    177         CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);
     225        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);
    178226
    179227        lock (locker) {
     
    256304      var targetExp = Expression.Parameter(typeof(RFParameter));
    257305      var valueExp = Expression.Parameter(typeof(double));
    258       var fieldExp = Expression.Field(targetExp, field);
     306      var fieldExp = Expression.Property(targetExp, field);
    259307      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    260308      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
Note: See TracChangeset for help on using the changeset viewer.