Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/24/14 15:16:59 (10 years ago)
Author:
bburlacu
Message:

#2276: Commit initial version of IDataset interface and code refactoring.

Location:
branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r11343 r11571  
    129129    }
    130130
    131     public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
     131    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    132132      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    133133      AssertInputMatrix(inputData);
     
    147147    }
    148148
    149     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     149    public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    150150      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    151151      AssertInputMatrix(inputData);
     
    205205      outOfBagRmsError = rep.oobrmserror;
    206206
    207       return new RandomForestModel(dForest,seed, problemData,nTrees, r, m);
     207      return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
    208208    }
    209209
     
    242242      outOfBagRelClassificationError = rep.oobrelclserror;
    243243
    244       return new RandomForestModel(dForest,seed, problemData,nTrees, r, m, classValues);
     244      return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
    245245    }
    246246
  • branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11445 r11571  
    9090
    9191  public static class RandomForestUtil {
    92     private static readonly object locker = new object();
    93 
    9492    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    9593      avgTestMse = 0;
     
    132130    }
    133131
    134     // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased
     132    /// <summary>
     133    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
     134    /// </summary>
     135    /// <param name="problemData">The regression problem data</param>
     136    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     137    /// <param name="seed">The random seed (required by the random forest model)</param>
     138    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    135139    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    136140      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    139143      RFParameter bestParameters = new RFParameter();
    140144
     145      var locker = new object();
    141146      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    142147        var parameterValues = parameterCombination.ToList();
     
    156161    }
    157162
     163    /// <summary>
     164    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
     165    /// </summary>
     166    /// <param name="problemData">The classification problem data</param>
     167    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     168    /// <param name="seed">The random seed (required by the random forest model)</param>
     169    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    158170    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    159171      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    163175      RFParameter bestParameters = new RFParameter();
    164176
     177      var locker = new object();
    165178      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    166179        var parameterValues = parameterCombination.ToList();
     
    181194    }
    182195
     196    /// <summary>
     197    /// Grid search with crossvalidation
     198    /// </summary>
     199    /// <param name="problemData">The regression problem data</param>
     200    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
     201    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
     202    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     203    /// <param name="seed">The random seed (required by the random forest model)</param>
     204    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
     205    /// <returns>The best parameter values found by the grid search</returns>
    183206    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    184207      DoubleValue mse = new DoubleValue(Double.MaxValue);
     
    189212      var crossProduct = parameterRanges.Values.CartesianProduct();
    190213
     214      var locker = new object();
    191215      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    192216        var parameterValues = parameterCombination.ToList();
     
    208232    }
    209233
     234    /// <summary>
     235    /// Grid search with crossvalidation
     236    /// </summary>
     237    /// <param name="problemData">The classification problem data</param>
     238    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
     239    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
     240    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     241    /// <param name="seed">The random seed (for shuffling)</param>
     242    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    210243    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    211244      DoubleValue accuracy = new DoubleValue(0);
     
    216249      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
    217250
     251      var locker = new object();
    218252      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    219253        var parameterValues = parameterCombination.ToList();
Note: See TracChangeset for help on using the changeset viewer.