Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/25/15 13:46:24 (9 years ago)
Author:
mkommend
Message:

#2276: Reintegrated branch for dataset refactoring.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r12012 r12509  
    9090
    9191  public static class RandomForestUtil {
    92     private static readonly object locker = new object();
    93 
    9492    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    9593      avgTestMse = 0;
     
    132130    }
    133131
    134     // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased
     132    /// <summary>
     133    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
     134    /// </summary>
     135    /// <param name="problemData">The regression problem data</param>
     136    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     137    /// <param name="seed">The random seed (required by the random forest model)</param>
     138    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    135139    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    136140      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    139143      RFParameter bestParameters = new RFParameter();
    140144
     145      var locker = new object();
    141146      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    142147        var parameterValues = parameterCombination.ToList();
     
    156161    }
    157162
     163    /// <summary>
     164    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
     165    /// </summary>
     166    /// <param name="problemData">The classification problem data</param>
     167    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     168    /// <param name="seed">The random seed (required by the random forest model)</param>
     169    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    158170    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    159171      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    163175      RFParameter bestParameters = new RFParameter();
    164176
     177      var locker = new object();
    165178      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    166179        var parameterValues = parameterCombination.ToList();
     
    181194    }
    182195
     196    /// <summary>
     197    /// Grid search with crossvalidation
     198    /// </summary>
     199    /// <param name="problemData">The regression problem data</param>
     200    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
     201    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
     202    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     203    /// <param name="seed">The random seed (required by the random forest model)</param>
     204    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
     205    /// <returns>The best parameter values found by the grid search</returns>
    183206    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    184207      DoubleValue mse = new DoubleValue(Double.MaxValue);
     
    189212      var crossProduct = parameterRanges.Values.CartesianProduct();
    190213
     214      var locker = new object();
    191215      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    192216        var parameterValues = parameterCombination.ToList();
     
    208232    }
    209233
     234    /// <summary>
     235    /// Grid search with crossvalidation
     236    /// </summary>
     237    /// <param name="problemData">The classification problem data</param>
     238    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
     239    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
     240    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
     241    /// <param name="seed">The random seed (for shuffling)</param>
     242    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    210243    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    211244      DoubleValue accuracy = new DoubleValue(0);
     
    216249      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
    217250
     251      var locker = new object();
    218252      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    219253        var parameterValues = parameterCombination.ToList();
Note: See TracChangeset for help on using the changeset viewer.