Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/27/14 11:23:37 (10 years ago)
Author:
jkarder
Message:

#2116: merged r10041-r11593 from trunk into branch

Location:
branches/Breadcrumbs
Files:
7 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/Breadcrumbs

  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis

  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r9456 r11594  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    131131    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    132132      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    133       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    134       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    135 
    136       alglib.math.rndobject = new System.Random(seed);
    137 
    138       Dataset dataset = problemData.Dataset;
    139       string targetVariable = problemData.TargetVariable;
    140       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    141       IEnumerable<int> rows = problemData.TrainingIndices;
    142       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    143       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    144         throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
    145 
    146       int info = 0;
    147       alglib.decisionforest dForest = new alglib.decisionforest();
    148       alglib.dfreport rep = new alglib.dfreport(); ;
    149       int nRows = inputMatrix.GetLength(0);
    150       int nColumns = inputMatrix.GetLength(1);
    151       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    152       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    153 
    154 
    155       double[] classValues = problemData.ClassValues.ToArray();
    156       int nClasses = problemData.Classes;
    157       // map original class values to values [0..nClasses-1]
    158       Dictionary<double, double> classIndices = new Dictionary<double, double>();
    159       for (int i = 0; i < nClasses; i++) {
    160         classIndices[classValues[i]] = i;
    161       }
    162       for (int row = 0; row < nRows; row++) {
    163         inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    164       }
    165       // execute random forest algorithm     
    166       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    167       if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    168 
    169       rmsError = rep.rmserror;
    170       outOfBagRmsError = rep.oobrmserror;
    171       relClassificationError = rep.relclserror;
    172       outOfBagRelClassificationError = rep.oobrelclserror;
    173       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
     133      var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     134      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    174135    }
    175136    #endregion
  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs

    r9456 r11594  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r9456 r11594  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3636  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    37 
     37    // not persisted
    3838    private alglib.decisionforest randomForest;
    39     public alglib.decisionforest RandomForest {
    40       get { return randomForest; }
    41       set {
    42         if (value != randomForest) {
    43           if (value == null) throw new ArgumentNullException();
    44           randomForest = value;
    45           OnChanged(EventArgs.Empty);
    46         }
    47       }
    48     }
    49 
    50     [Storable]
    51     private string targetVariable;
    52     [Storable]
    53     private string[] allowedInputVariables;
     39    private alglib.decisionforest RandomForest {
     40      get {
     41        // recalculate lazily
     42        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
     43        return randomForest;
     44      }
     45    }
     46
     47    // instead of storing the data of the model itself
      48    // we only store data necessary to recalculate the same model lazily on demand
     49    [Storable]
     50    private int seed;
     51    [Storable]
     52    private IDataAnalysisProblemData originalTrainingData;
    5453    [Storable]
    5554    private double[] classValues;
     55    [Storable]
     56    private int nTrees;
     57    [Storable]
     58    private double r;
     59    [Storable]
     60    private double m;
     61
     62
    5663    [StorableConstructor]
    5764    private RandomForestModel(bool deserializing)
    5865      : base(deserializing) {
    59       if (deserializing)
    60         randomForest = new alglib.decisionforest();
     66      // for backwards compatibility (loading old solutions)
     67      randomForest = new alglib.decisionforest();
    6168    }
    6269    private RandomForestModel(RandomForestModel original, Cloner cloner)
     
    6774      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    6875      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    69       randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
    70       targetVariable = original.targetVariable;
    71       allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    72       if (original.classValues != null)
    73         this.classValues = (double[])original.classValues.Clone();
    74     }
    75     public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
     76      // we assume that the trees array (double[]) is immutable in alglib
     77      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
     78
     79      // allowedInputVariables is immutable so we don't need to clone
     80      allowedInputVariables = original.allowedInputVariables;
     81
     82      // clone data which is necessary to rebuild the model
     83      this.seed = original.seed;
     84      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      85      // classValues is immutable so we don't need to clone
     86      this.classValues = original.classValues;
     87      this.nTrees = original.nTrees;
     88      this.r = original.r;
     89      this.m = original.m;
     90    }
     91
     92    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
     93    private RandomForestModel(alglib.decisionforest randomForest,
     94      int seed, IDataAnalysisProblemData originalTrainingData,
     95      int nTrees, double r, double m, double[] classValues = null)
    7696      : base() {
    7797      this.name = ItemName;
    7898      this.description = ItemDescription;
     99      // the model itself
    79100      this.randomForest = randomForest;
    80       this.targetVariable = targetVariable;
    81       this.allowedInputVariables = allowedInputVariables.ToArray();
    82       if (classValues != null)
    83         this.classValues = (double[])classValues.Clone();
     101      // data which is necessary for recalculation of the model
     102      this.seed = seed;
     103      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
     104      this.classValues = classValues;
     105      this.nTrees = nTrees;
     106      this.r = r;
     107      this.m = m;
    84108    }
    85109
     
    88112    }
    89113
     114    private void RecalculateModel() {
     115      double rmsError, oobRmsError, relClassError, oobRelClassError;
     116      var regressionProblemData = originalTrainingData as IRegressionProblemData;
     117      var classificationProblemData = originalTrainingData as IClassificationProblemData;
     118      if (regressionProblemData != null) {
     119        var model = CreateRegressionModel(regressionProblemData,
     120                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     121                                              out relClassError, out oobRelClassError);
     122        randomForest = model.randomForest;
     123      } else if (classificationProblemData != null) {
     124        var model = CreateClassificationModel(classificationProblemData,
     125                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     126                                              out relClassError, out oobRelClassError);
     127        randomForest = model.randomForest;
     128      }
     129    }
     130
    90131    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    91       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     132      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
     133      AssertInputMatrix(inputData);
    92134
    93135      int n = inputData.GetLength(0);
     
    100142          x[column] = inputData[row, column];
    101143        }
    102         alglib.dfprocess(randomForest, x, ref y);
     144        alglib.dfprocess(RandomForest, x, ref y);
    103145        yield return y[0];
    104146      }
     
    106148
    107149    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    108       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     150      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
     151      AssertInputMatrix(inputData);
    109152
    110153      int n = inputData.GetLength(0);
    111154      int columns = inputData.GetLength(1);
    112155      double[] x = new double[columns];
    113       double[] y = new double[randomForest.innerobj.nclasses];
     156      double[] y = new double[RandomForest.innerobj.nclasses];
    114157
    115158      for (int row = 0; row < n; row++) {
     
    144187    }
    145188
    146     #region events
    147     public event EventHandler Changed;
    148     private void OnChanged(EventArgs e) {
    149       var handlers = Changed;
    150       if (handlers != null)
    151         handlers(this, e);
    152     }
    153     #endregion
    154 
    155     #region persistence
     189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
     190      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
     191      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
     192    }
     193
     194    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     195      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
     196      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     197      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
     198
     199      alglib.dfreport rep;
     200      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     201
     202      rmsError = rep.rmserror;
     203      avgRelError = rep.avgrelerror;
     204      outOfBagAvgRelError = rep.oobavgrelerror;
     205      outOfBagRmsError = rep.oobrmserror;
     206
     207      return new RandomForestModel(dForest,seed, problemData,nTrees, r, m);
     208    }
     209
     210    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     211      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     212      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
     213    }
     214
     215    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     216      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     217
     218      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     219      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
     220
     221      var classValues = problemData.ClassValues.ToArray();
     222      int nClasses = classValues.Length;
     223
     224      // map original class values to values [0..nClasses-1]
     225      var classIndices = new Dictionary<double, double>();
     226      for (int i = 0; i < nClasses; i++) {
     227        classIndices[classValues[i]] = i;
     228      }
     229
     230      int nRows = inputMatrix.GetLength(0);
     231      int nColumns = inputMatrix.GetLength(1);
     232      for (int row = 0; row < nRows; row++) {
     233        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     234      }
     235
     236      alglib.dfreport rep;
     237      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     238
     239      rmsError = rep.rmserror;
     240      outOfBagRmsError = rep.oobrmserror;
     241      relClassificationError = rep.relclserror;
     242      outOfBagRelClassificationError = rep.oobrelclserror;
     243
     244      return new RandomForestModel(dForest,seed, problemData,nTrees, r, m, classValues);
     245    }
     246
     247    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     248      AssertParameters(r, m);
     249      AssertInputMatrix(inputMatrix);
     250
     251      int info = 0;
     252      alglib.math.rndobject = new System.Random(seed);
     253      var dForest = new alglib.decisionforest();
     254      rep = new alglib.dfreport();
     255      int nRows = inputMatrix.GetLength(0);
     256      int nColumns = inputMatrix.GetLength(1);
     257      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     258      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     259
     260      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     261      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     262      return dForest;
     263    }
     264
     265    private static void AssertParameters(double r, double m) {
     266      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     267      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     268    }
     269
     270    private static void AssertInputMatrix(double[,] inputMatrix) {
     271      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
     272        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     273    }
     274
     275    #region persistence for backwards compatibility
     276    // when the originalTrainingData is null this means the model was loaded from an old file
     277    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
      278    // in such cases we still store the complete model
     279    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
     280
     281    private string[] allowedInputVariables;
     282    [Storable(Name = "allowedInputVariables")]
     283    private string[] AllowedInputVariables {
     284      get {
     285        if (IsCompatibilityLoaded) return allowedInputVariables;
     286        else return originalTrainingData.AllowedInputVariables.ToArray();
     287      }
     288      set { allowedInputVariables = value; }
     289    }
    156290    [Storable]
    157291    private int RandomForestBufSize {
    158292      get {
    159         return randomForest.innerobj.bufsize;
     293        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
     294        else return 0;
    160295      }
    161296      set {
     
    166301    private int RandomForestNClasses {
    167302      get {
    168         return randomForest.innerobj.nclasses;
     303        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
     304        else return 0;
    169305      }
    170306      set {
     
    175311    private int RandomForestNTrees {
    176312      get {
    177         return randomForest.innerobj.ntrees;
     313        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
     314        else return 0;
    178315      }
    179316      set {
     
    184321    private int RandomForestNVars {
    185322      get {
    186         return randomForest.innerobj.nvars;
     323        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
     324        else return 0;
    187325      }
    188326      set {
     
    193331    private double[] RandomForestTrees {
    194332      get {
    195         return randomForest.innerobj.trees;
     333        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
     334        else return new double[] { };
    196335      }
    197336      set {
  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r9456 r11594  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    130130    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    131131      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    132       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    133       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    134 
    135       alglib.math.rndobject = new System.Random(seed);
    136 
    137       Dataset dataset = problemData.Dataset;
    138       string targetVariable = problemData.TargetVariable;
    139       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    140       IEnumerable<int> rows = problemData.TrainingIndices;
    141       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    142       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    143         throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");
    144 
    145       int info = 0;
    146       alglib.decisionforest dForest = new alglib.decisionforest();
    147       alglib.dfreport rep = new alglib.dfreport(); ;
    148       int nRows = inputMatrix.GetLength(0);
    149       int nColumns = inputMatrix.GetLength(1);
    150       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    151       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    152 
    153       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    154       if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    155 
    156       rmsError = rep.rmserror;
    157       avgRelError = rep.avgrelerror;
    158       outOfBagAvgRelError = rep.oobavgrelerror;
    159       outOfBagRmsError = rep.oobrmserror;
    160 
    161       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
     132      var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     133      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
    162134    }
    163135    #endregion
  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs

    r9456 r11594  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
  • branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11400 r11594  
    3030using HeuristicLab.Core;
    3131using HeuristicLab.Data;
     32using HeuristicLab.Parameters;
     33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3234using HeuristicLab.Problems.DataAnalysis;
    3335using HeuristicLab.Random;
    3436
    3537namespace HeuristicLab.Algorithms.DataAnalysis {
    36   public class RFParameter : ICloneable {
    37     public double n; // number of trees
    38     public double m;
    39     public double r;
    40 
    41     public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; }
     38  [Item("RFParameter", "A random forest parameter collection")]
     39  [StorableClass]
     40  public class RFParameter : ParameterCollection {
     41    public RFParameter() {
     42      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
     43      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
     44      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
     45    }
     46
     47    [StorableConstructor]
     48    protected RFParameter(bool deserializing)
     49      : base(deserializing) {
     50    }
     51
     52    protected RFParameter(RFParameter original, Cloner cloner)
     53      : base(original, cloner) {
     54      this.N = original.N;
     55      this.R = original.R;
     56      this.M = original.M;
     57    }
     58
     59    public override IDeepCloneable Clone(Cloner cloner) {
     60      return new RFParameter(this, cloner);
     61    }
     62
     63    private IFixedValueParameter<IntValue> NParameter {
     64      get { return (IFixedValueParameter<IntValue>)base["N"]; }
     65    }
     66
     67    private IFixedValueParameter<DoubleValue> RParameter {
     68      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
     69    }
     70
     71    private IFixedValueParameter<DoubleValue> MParameter {
     72      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
     73    }
     74
     75    public int N {
     76      get { return NParameter.Value.Value; }
     77      set { NParameter.Value.Value = value; }
     78    }
     79
     80    public double R {
     81      get { return RParameter.Value.Value; }
     82      set { RParameter.Value.Value = value; }
     83    }
     84
     85    public double M {
     86      get { return MParameter.Value.Value; }
     87      set { MParameter.Value.Value = value; }
     88    }
    4289  }
    4390
    4491  public static class RandomForestUtil {
     92    private static readonly object locker = new object();
     93
    4594    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    4695      avgTestMse = 0;
     
    62111      avgTestMse /= partitions.Length;
    63112    }
     113
    64114    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    65115      avgTestAccuracy = 0;
     
    91141      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    92142        var parameterValues = parameterCombination.ToList();
    93         double testMSE;
    94143        var parameters = new RFParameter();
    95144        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
    96145        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    97         var model = RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
    98                                                             out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    99         if (bestOutOfBagRmsError > outOfBagRmsError) {
    100           lock (bestParameters) {
     146        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
     147
     148        lock (locker) {
     149          if (bestOutOfBagRmsError > outOfBagRmsError) {
    101150            bestOutOfBagRmsError = outOfBagRmsError;
    102151            bestParameters = (RFParameter)parameters.Clone();
     
    119168        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
    120169        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    121         var model = RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
     170        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
    122171                                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    123         if (bestOutOfBagRmsError > outOfBagRmsError) {
    124           lock (bestParameters) {
     172
     173        lock (locker) {
     174          if (bestOutOfBagRmsError > outOfBagRmsError) {
    125175            bestOutOfBagRmsError = outOfBagRmsError;
    126176            bestParameters = (RFParameter)parameters.Clone();
     
    133183    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    134184      DoubleValue mse = new DoubleValue(Double.MaxValue);
    135       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
     185      RFParameter bestParameter = new RFParameter();
    136186
    137187      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    146196          setters[i](parameters, parameterValues[i]);
    147197        }
    148         CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE);
    149         if (testMSE < mse.Value) {
    150           lock (mse) {
     198        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
     199
     200        lock (locker) {
     201          if (testMSE < mse.Value) {
    151202            mse.Value = testMSE;
    152203            bestParameter = (RFParameter)parameters.Clone();
     
    159210    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    160211      DoubleValue accuracy = new DoubleValue(0);
    161       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
     212      RFParameter bestParameter = new RFParameter();
    162213
    163214      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    172223          setters[i](parameters, parameterValues[i]);
    173224        }
    174         CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);
    175         if (testAccuracy > accuracy.Value) {
    176           lock (accuracy) {
     225        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);
     226
     227        lock (locker) {
     228          if (testAccuracy > accuracy.Value) {
    177229            accuracy.Value = testAccuracy;
    178230            bestParameter = (RFParameter)parameters.Clone();
     
    252304      var targetExp = Expression.Parameter(typeof(RFParameter));
    253305      var valueExp = Expression.Parameter(typeof(double));
    254       var fieldExp = Expression.Field(targetExp, field);
     306      var fieldExp = Expression.Property(targetExp, field);
    255307      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    256308      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
Note: See TracChangeset for help on using the changeset viewer.