Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/04/14 17:31:46 (10 years ago)
Author:
mkommend
Message:

#2237: Corrected newly introduced bug in RandomForestModel and reorganized RandomForestUtil.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r11338 r11343  
    189189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    190190      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
    191       return CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, problemData.TrainingIndices);
    192     }
    193 
    194     public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    195       out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError, IEnumerable<int> trainingIndices) {
     191      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
     192    }
     193
     194    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     195      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
    196196      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
    197       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
     197      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
    198198
    199199      alglib.dfreport rep;
     
    205205      outOfBagRmsError = rep.oobrmserror;
    206206
    207       return new RandomForestModel(dForest,
    208         seed, problemData,
    209         nTrees, r, m);
     207      return new RandomForestModel(dForest,seed, problemData,nTrees, r, m);
    210208    }
    211209
    212210    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    213211      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    214       return CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError, problemData.TrainingIndices);
    215     }
    216 
    217     public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    218       out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError, IEnumerable<int> trainingIndices) {
     212      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
     213    }
     214
     215    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     216      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    219217
    220218      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     
    244242      outOfBagRelClassificationError = rep.oobrelclserror;
    245243
    246       return new RandomForestModel(dForest,
    247         seed, problemData,
    248         nTrees, r, m, classValues);
     244      return new RandomForestModel(dForest,seed, problemData,nTrees, r, m, classValues);
    249245    }
    250246
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11338 r11343  
    4141
    4242  public static class RandomForestUtil {
    43     private static Action<RFParameter, double> GenerateSetter(string field) {
    44       var targetExp = Expression.Parameter(typeof(RFParameter));
    45       var valueExp = Expression.Parameter(typeof(double));
    46       var fieldExp = Expression.Field(targetExp, field);
    47       var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    48       var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
    49       return setter;
     43    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
     44      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
     45    }
     46    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
     47      avgTestMse = 0;
     48      var ds = problemData.Dataset;
     49      var targetVariable = GetTargetVariableName(problemData);
     50      foreach (var tuple in partitions) {
     51        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
     52        var trainingRandomForestPartition = tuple.Item1;
     53        var testRandomForestPartition = tuple.Item2;
     54        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     55        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
     56        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
     57        OnlineCalculatorError calculatorError;
     58        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
     59        if (calculatorError != OnlineCalculatorError.None)
     60          mse = double.NaN;
     61        avgTestMse += mse;
     62      }
     63      avgTestMse /= partitions.Length;
     64    }
     65
     66    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
     67      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
     68    }
     69    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
     70      avgTestAccuracy = 0;
     71      var ds = problemData.Dataset;
     72      var targetVariable = GetTargetVariableName(problemData);
     73      foreach (var tuple in partitions) {
     74        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
     75        var trainingRandomForestPartition = tuple.Item1;
     76        var testRandomForestPartition = tuple.Item2;
     77        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     78        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
     79        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
     80        OnlineCalculatorError calculatorError;
     81        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
     82        if (calculatorError != OnlineCalculatorError.None)
     83          accuracy = double.NaN;
     84        avgTestAccuracy += accuracy;
     85      }
     86      avgTestAccuracy /= partitions.Length;
     87    }
     88
     89    private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     90      DoubleValue mse = new DoubleValue(Double.MaxValue);
     91      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
     92
     93      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     94      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
     95      var crossProduct = parameterRanges.Values.CartesianProduct();
     96
     97      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     98        var parameterValues = parameterCombination.ToList();
     99        double testMSE;
     100        var parameters = new RFParameter();
     101        for (int i = 0; i < setters.Count; ++i) {
     102          setters[i](parameters, parameterValues[i]);
     103        }
     104        CrossValidate(problemData, partitions, parameters, seed, out testMSE);
     105        if (testMSE < mse.Value) {
     106          lock (mse) {
     107            mse.Value = testMSE;
     108            bestParameter = (RFParameter)parameters.Clone();
     109          }
     110        }
     111      });
     112      return bestParameter;
     113    }
     114
     115    private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     116      DoubleValue accuracy = new DoubleValue(0);
     117      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
     118
     119      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     120      var crossProduct = parameterRanges.Values.CartesianProduct();
     121      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
     122
     123      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     124        var parameterValues = parameterCombination.ToList();
     125        double testAccuracy;
     126        var parameters = new RFParameter();
     127        for (int i = 0; i < setters.Count; ++i) {
     128          setters[i](parameters, parameterValues[i]);
     129        }
     130        CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
     131        if (testAccuracy > accuracy.Value) {
     132          lock (accuracy) {
     133            accuracy.Value = testAccuracy;
     134            bestParameter = (RFParameter)parameters.Clone();
     135          }
     136        }
     137      });
     138      return bestParameter;
    50139    }
    51140
     
    57146    /// <param name="numberOfFolds">The number of folds to generate</param>
    58147    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    59     public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
     148    private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    60149      int size = problemData.TrainingPartition.Size;
    61150      int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
     
    82171    }
    83172
    84     public static void CrossValidate(IDataAnalysisProblemData problemData, int numberOfFolds, RFParameter parameters, int seed, out double error) {
    85       var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    86       CrossValidate(problemData, partitions, parameters, seed, out error);
    87     }
    88173
    89     // user should call the more specific CrossValidate methods
    90     public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double error) {
    91       CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out error);
    92     }
    93 
    94     public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double error) {
    95       var regressionProblemData = problemData as IRegressionProblemData;
    96       var classificationProblemData = problemData as IClassificationProblemData;
    97       if (regressionProblemData != null)
    98         CrossValidate(regressionProblemData, partitions, nTrees, m, r, seed, out error);
    99       else if (classificationProblemData != null)
    100         CrossValidate(classificationProblemData, partitions, nTrees, m, r, seed, out error);
    101       else throw new ArgumentException("Problem data is neither regression or classification problem data.");
    102     }
    103 
    104     private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
    105       CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
    106     }
    107 
    108     private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
    109       CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
    110     }
    111 
    112     private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    113       avgTestMse = 0;
    114       var ds = problemData.Dataset;
    115       var targetVariable = GetTargetVariableName(problemData);
    116       foreach (var tuple in partitions) {
    117         double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
    118         var trainingRandomForestPartition = tuple.Item1;
    119         var testRandomForestPartition = tuple.Item2;
    120         var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
    121         var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
    122         var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
    123         OnlineCalculatorError calculatorError;
    124         double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
    125         if (calculatorError != OnlineCalculatorError.None)
    126           mse = double.NaN;
    127         avgTestMse += mse;
    128       }
    129       avgTestMse /= partitions.Length;
    130     }
    131 
    132     private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    133       avgTestAccuracy = 0;
    134       var ds = problemData.Dataset;
    135       var targetVariable = GetTargetVariableName(problemData);
    136       foreach (var tuple in partitions) {
    137         double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
    138         var trainingRandomForestPartition = tuple.Item1;
    139         var testRandomForestPartition = tuple.Item2;
    140         var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
    141         var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
    142         var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
    143         OnlineCalculatorError calculatorError;
    144         double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
    145         if (calculatorError != OnlineCalculatorError.None)
    146           accuracy = double.NaN;
    147         avgTestAccuracy += accuracy;
    148       }
    149       avgTestAccuracy /= partitions.Length;
    150     }
    151 
    152     public static RFParameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    153       var regressionProblemData = problemData as IRegressionProblemData;
    154       var classificationProblemData = problemData as IClassificationProblemData;
    155 
    156       if (regressionProblemData != null)
    157         return GridSearch(regressionProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
    158       if (classificationProblemData != null)
    159         return GridSearch(classificationProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
    160 
    161       throw new ArgumentException("Problem data is neither regression or classification problem data.");
    162     }
    163 
    164     private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    165       DoubleValue mse = new DoubleValue(Double.MaxValue);
    166       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
    167 
    168       var pNames = parameterRanges.Keys.ToList();
    169       var pRanges = pNames.Select(x => parameterRanges[x]);
    170       var setters = pNames.Select(GenerateSetter).ToList();
    171       var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    172       var crossProduct = pRanges.CartesianProduct();
    173 
    174       Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
    175         var list = nuple.ToList();
    176         double testMSE;
    177         var parameters = new RFParameter();
    178         for (int i = 0; i < pNames.Count; ++i) {
    179           var s = setters[i];
    180           s(parameters, list[i]);
    181         }
    182         CrossValidate(problemData, partitions, parameters, seed, out testMSE);
    183         if (testMSE < mse.Value) {
    184           lock (mse) { mse.Value = testMSE; }
    185           lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
    186         }
    187       });
    188       return bestParameter;
    189     }
    190 
    191     private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    192       DoubleValue accuracy = new DoubleValue(0);
    193       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
    194 
    195       var pNames = parameterRanges.Keys.ToList();
    196       var pRanges = pNames.Select(x => parameterRanges[x]);
    197       var setters = pNames.Select(GenerateSetter).ToList();
    198       var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    199       var crossProduct = pRanges.CartesianProduct();
    200 
    201       Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
    202         var list = nuple.ToList();
    203         double testAccuracy;
    204         var parameters = new RFParameter();
    205         for (int i = 0; i < pNames.Count; ++i) {
    206           var s = setters[i];
    207           s(parameters, list[i]);
    208         }
    209         CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
    210         if (testAccuracy > accuracy.Value) {
    211           lock (accuracy) { accuracy.Value = testAccuracy; }
    212           lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
    213         }
    214       });
    215       return bestParameter;
     174    private static Action<RFParameter, double> GenerateSetter(string field) {
     175      var targetExp = Expression.Parameter(typeof(RFParameter));
     176      var valueExp = Expression.Parameter(typeof(double));
     177      var fieldExp = Expression.Field(targetExp, field);
     178      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
     179      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
     180      return setter;
    216181    }
    217182
Note: See TracChangeset for help on using the changeset viewer.