
Timestamp:
02/05/15 10:29:24
Author:
mkommend
Message:

#2237: Merged r11315, r11317, r11338, r11343, r11362, r11426, r11443, r11445, r11446, r11448 into stable.

Location:
stable
Files:
3 edited
1 copied

Legend:
  ' ' unmodified
  '+' added
  '-' removed
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r11170 → r11901

    @@ -76,5 +76,5 @@
           // we assume that the trees array (double[]) is immutable in alglib
           randomForest.innerobj.trees = original.randomForest.innerobj.trees;
    -      
    +
           // allowedInputVariables is immutable so we don't need to clone
           allowedInputVariables = original.allowedInputVariables;
     
    @@ -188,8 +188,12 @@
     
         public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    -      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
    -
    +      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
    +      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
    +    }
    +
    +    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
    +      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
           var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
    -      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
    +      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
     
           alglib.dfreport rep;
     
    @@ -201,14 +205,17 @@
           outOfBagRmsError = rep.oobrmserror;
     
    -      return new RandomForestModel(dForest,
    -        seed, problemData,
    -        nTrees, r, m);
    +      return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
         }
     
         public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
           out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    +      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    +    }
    +
    +    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
    +      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     
           var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
    -      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
    +      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
     
           var classValues = problemData.ClassValues.ToArray();
     
    @@ -235,7 +242,5 @@
           outOfBagRelClassificationError = rep.oobrelclserror;
     
    -      return new RandomForestModel(dForest,
    -        seed, problemData,
    -        nTrees, r, m, classValues);
    +      return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
         }
     
     
    @@ -264,5 +269,5 @@
     
         private static void AssertInputMatrix(double[,] inputMatrix) {
    -      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    +      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
            throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
         }
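
    The net effect of this file's changes is a pair of new overloads that take an explicit set of training row indices, so a forest can be trained on an arbitrary row subset instead of the whole training partition; the cross-validation code in RandomForestUtil below relies on exactly this. A minimal usage sketch, assuming an already loaded IRegressionProblemData named problemData; the even-row filter and parameter values are illustrative only, not part of the changeset:

        // Hypothetical sketch: train on every second training row.
        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
        IEnumerable<int> customRows = problemData.TrainingIndices.Where(i => i % 2 == 0);
        var model = RandomForestModel.CreateRegressionModel(problemData, customRows,
          50, 0.3, 0.5, 12345,
          out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);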
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11315 → r11901

    @@ -28,5 +28,9 @@
     using System.Threading.Tasks;
     using HeuristicLab.Common;
    +using HeuristicLab.Core;
     using HeuristicLab.Data;
    +using HeuristicLab.Parameters;
    +using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     using HeuristicLab.Problems.DataAnalysis;
    +using HeuristicLab.Random;
     
    @@ -33,9 +37,54 @@
     namespace HeuristicLab.Algorithms.DataAnalysis {
    -  public class RFParameter : ICloneable {
    -    public double n; // number of trees
    -    public double m;
    -    public double r;
    -
    -    public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; }
    +  [Item("RFParameter", "A random forest parameter collection")]
    +  [StorableClass]
    +  public class RFParameter : ParameterCollection {
    +    public RFParameter() {
    +      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
    +      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
    +      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
    +    }
    +
    +    [StorableConstructor]
    +    protected RFParameter(bool deserializing)
    +      : base(deserializing) {
    +    }
    +
    +    protected RFParameter(RFParameter original, Cloner cloner)
    +      : base(original, cloner) {
    +      this.N = original.N;
    +      this.R = original.R;
    +      this.M = original.M;
    +    }
    +
    +    public override IDeepCloneable Clone(Cloner cloner) {
    +      return new RFParameter(this, cloner);
    +    }
    +
    +    private IFixedValueParameter<IntValue> NParameter {
    +      get { return (IFixedValueParameter<IntValue>)base["N"]; }
    +    }
    +
    +    private IFixedValueParameter<DoubleValue> RParameter {
    +      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
    +    }
    +
    +    private IFixedValueParameter<DoubleValue> MParameter {
    +      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
    +    }
    +
    +    public int N {
    +      get { return NParameter.Value.Value; }
    +      set { NParameter.Value.Value = value; }
    +    }
    +
    +    public double R {
    +      get { return RParameter.Value.Value; }
    +      set { RParameter.Value.Value = value; }
    +    }
    +
    +    public double M {
    +      get { return MParameter.Value.Value; }
    +      set { MParameter.Value.Value = value; }
    +    }
       }
     
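
    Since RFParameter is now a ParameterCollection rather than a plain field bag, callers configure it through the typed N/R/M properties, and copies go through HeuristicLab's Cloner machinery (the GridSearch methods below do exactly this per candidate). A short sketch of the new surface; the values are illustrative:

        var p = new RFParameter();         // defaults: N = 50, M = 0.1, R = 0.1
        p.N = 100;                         // number of trees
        p.R = 0.5;                         // ratio of training rows used per tree
        p.M = 0.3;                         // ratio of features used per tree
        var copy = (RFParameter)p.Clone(); // deep clone, as used by GridSearch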
    @@ -42,1 +91,211 @@
       public static class RandomForestUtil {
    +    private static readonly object locker = new object();
    +
    +    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    +      avgTestMse = 0;
    +      var ds = problemData.Dataset;
    +      var targetVariable = GetTargetVariableName(problemData);
    +      foreach (var tuple in partitions) {
    +        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
    +        var trainingRandomForestPartition = tuple.Item1;
    +        var testRandomForestPartition = tuple.Item2;
    +        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    +        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
    +        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
    +        OnlineCalculatorError calculatorError;
    +        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
    +        if (calculatorError != OnlineCalculatorError.None)
    +          mse = double.NaN;
    +        avgTestMse += mse;
    +      }
    +      avgTestMse /= partitions.Length;
    +    }
    +
    +    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    +      avgTestAccuracy = 0;
    +      var ds = problemData.Dataset;
    +      var targetVariable = GetTargetVariableName(problemData);
    +      foreach (var tuple in partitions) {
    +        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
    +        var trainingRandomForestPartition = tuple.Item1;
    +        var testRandomForestPartition = tuple.Item2;
    +        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    +        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
    +        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
    +        OnlineCalculatorError calculatorError;
    +        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
    +        if (calculatorError != OnlineCalculatorError.None)
    +          accuracy = double.NaN;
    +        avgTestAccuracy += accuracy;
    +      }
    +      avgTestAccuracy /= partitions.Length;
    +    }
    +
    +    // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased
    +    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    +      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    +      var crossProduct = parameterRanges.Values.CartesianProduct();
    +      double bestOutOfBagRmsError = double.MaxValue;
    +      RFParameter bestParameters = new RFParameter();
    +
    +      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    +        var parameterValues = parameterCombination.ToList();
    +        var parameters = new RFParameter();
    +        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
    +        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    +        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    +
    +        lock (locker) {
    +          if (bestOutOfBagRmsError > outOfBagRmsError) {
    +            bestOutOfBagRmsError = outOfBagRmsError;
    +            bestParameters = (RFParameter)parameters.Clone();
    +          }
    +        }
    +      });
    +      return bestParameters;
    +    }
    +
    +    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    +      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    +      var crossProduct = parameterRanges.Values.CartesianProduct();
    +
    +      double bestOutOfBagRmsError = double.MaxValue;
    +      RFParameter bestParameters = new RFParameter();
    +
    +      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    +        var parameterValues = parameterCombination.ToList();
    +        var parameters = new RFParameter();
    +        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
    +        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    +        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
    +                                                    out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    +
    +        lock (locker) {
    +          if (bestOutOfBagRmsError > outOfBagRmsError) {
    +            bestOutOfBagRmsError = outOfBagRmsError;
    +            bestParameters = (RFParameter)parameters.Clone();
    +          }
    +        }
    +      });
    +      return bestParameters;
    +    }
    +
    +    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    +      DoubleValue mse = new DoubleValue(Double.MaxValue);
    +      RFParameter bestParameter = new RFParameter();
    +
    +      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    +      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    +      var crossProduct = parameterRanges.Values.CartesianProduct();
    +
    +      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    +        var parameterValues = parameterCombination.ToList();
    +        double testMSE;
    +        var parameters = new RFParameter();
    +        for (int i = 0; i < setters.Count; ++i) {
    +          setters[i](parameters, parameterValues[i]);
    +        }
    +        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
    +
    +        lock (locker) {
    +          if (testMSE < mse.Value) {
    +            mse.Value = testMSE;
    +            bestParameter = (RFParameter)parameters.Clone();
    +          }
    +        }
    +      });
    +      return bestParameter;
    +    }
    +
    +    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    +      DoubleValue accuracy = new DoubleValue(0);
    +      RFParameter bestParameter = new RFParameter();
    +
    +      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    +      var crossProduct = parameterRanges.Values.CartesianProduct();
    +      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
    +
    +      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    +        var parameterValues = parameterCombination.ToList();
    +        double testAccuracy;
    +        var parameters = new RFParameter();
    +        for (int i = 0; i < setters.Count; ++i) {
    +          setters[i](parameters, parameterValues[i]);
    +        }
    +        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);
    +
    +        lock (locker) {
    +          if (testAccuracy > accuracy.Value) {
    +            accuracy.Value = testAccuracy;
    +            bestParameter = (RFParameter)parameters.Clone();
    +          }
    +        }
    +      });
    +      return bestParameter;
    +    }
    +
    +    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
    +      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
    +      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
    +
    +      for (int i = 0; i < numberOfFolds; ++i) {
    +        int p = i; // avoid "access to modified closure" warning
    +        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
    +        var testRows = folds[i];
    +        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
    +      }
    +      return partitions;
    +    }
    +
    +    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
    +      var random = new MersenneTwister((uint)Environment.TickCount);
    +      if (problemData is IRegressionProblemData) {
    +        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
    +        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
    +      }
    +      if (problemData is IClassificationProblemData) {
    +        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
    +        // otherwise, generate folds normally
    +        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
    +      }
    +      throw new ArgumentException("Problem data is neither regression nor classification problem data.");
    +    }
    +
    +    /// <summary>
    +    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
    +    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
    +    /// the corresponding parts from each class label.
    +    /// </summary>
    +    /// <param name="problemData">The classification problem data.</param>
    +    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
    +    /// <param name="random">The random generator used to shuffle the folds.</param>
    +    /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices.</returns>
    +    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
    +      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    +      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
    +      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
    +      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
    +      while (enumerators.All(e => e.MoveNext())) {
    +        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
    +      }
    +    }
    +
    +    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
    +      // if number of folds is greater than the number of values, some empty folds will be returned
    +      if (valuesCount < numberOfFolds) {
    +        for (int i = 0; i < numberOfFolds; ++i)
    +          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
    +      } else {
    +        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // integer fold size and remainder
    +        int start = 0, end = f;
    +        for (int i = 0; i < numberOfFolds; ++i) {
    +          if (r > 0) {
    +            ++end;
    +            --r;
    +          }
    +          yield return values.Skip(start).Take(end - start);
    +          start = end;
    +          end += f;
    +        }
    +      }
    +    }
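
    The splitting rule in the generic GenerateFolds<T> hands each fold f = valuesCount / numberOfFolds elements and gives one extra element to each of the first r folds, so fold sizes never differ by more than one. For example, 11 values into 3 folds gives f = 3, r = 2 and fold sizes 4, 4, 3. A standalone sketch of just the size arithmetic (my code, mirroring the loop above; the worked numbers are mine, not from the changeset):

        // Mirrors GenerateFolds<T>'s size logic: 11 values, 3 folds -> sizes 4, 4, 3.
        int valuesCount = 11, numberOfFolds = 3;
        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds;
        for (int i = 0; i < numberOfFolds; ++i) {
          int size = f + (r > 0 ? 1 : 0); // first r folds get one extra element
          if (r > 0) --r;
          Console.WriteLine(size);        // prints 4, 4, 3
        }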
    @@ -43,8 +302,7 @@
    +
         private static Action<RFParameter, double> GenerateSetter(string field) {
           var targetExp = Expression.Parameter(typeof(RFParameter));
           var valueExp = Expression.Parameter(typeof(double));
    -
    -      // Expression.Property can be used here as well
    -      var fieldExp = Expression.Field(targetExp, field);
    +      var fieldExp = Expression.Property(targetExp, field);
           var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
           var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
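
    The switch from Expression.Field to Expression.Property is forced by the RFParameter rewrite: N, R and M are now properties rather than public fields, so the old comment ("Expression.Property can be used here as well") no longer describes an option but the requirement. The compiled-setter technique itself is unchanged; a self-contained sketch of it (the Demo class is mine, for illustration):

        using System;
        using System.Linq.Expressions;

        public class Demo { public double M { get; set; } }

        public static class SetterDemo {
          public static void Main() {
            var targetExp = Expression.Parameter(typeof(Demo));
            var valueExp = Expression.Parameter(typeof(double));
            var propExp = Expression.Property(targetExp, "M"); // resolve the property by name
            var assignExp = Expression.Assign(propExp, Expression.Convert(valueExp, propExp.Type));
            var setter = Expression.Lambda<Action<Demo, double>>(assignExp, targetExp, valueExp).Compile();

            var d = new Demo();
            setter(d, 0.5);          // equivalent to d.M = 0.5, but driven by the string "M"
            Console.WriteLine(d.M);  // 0.5
          }
        }

    The Expression.Convert call is what lets GridSearch feed double values from the parameter ranges into the int-typed N property.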
     
    @@ -52,89 +310,14 @@
         }
     
    -    /// <summary>
    -    /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)
    -    /// </summary>
    -    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    -    /// <param name="problemData">The problem data</param>
    -    /// <param name="nFolds">The number of folds to generate</param>
    -    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    -    public static IEnumerable<IEnumerable<int>> GenerateFolds(IRegressionProblemData problemData, int nFolds) {
    -      int size = problemData.TrainingPartition.Size;
    -
    -      int foldSize = size / nFolds; // rounding to integer
    -      var trainingIndices = problemData.TrainingIndices;
    -
    -      for (int i = 0; i < nFolds; ++i) {
    -        int n = i * foldSize;
    -        int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize;
    -        yield return trainingIndices.Skip(n).Take(s);
    -      }
    -    }
    -
    -    public static void CrossValidate(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, RFParameter parameter, int seed, out double avgTestMSE) {
    -      CrossValidate(problemData, folds, (int)Math.Round(parameter.n), parameter.m, parameter.r, seed, out avgTestMSE);
    -    }
    -
    -    public static void CrossValidate(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, int nTrees, double m, double r, int seed, out double avgTestMSE) {
    -      avgTestMSE = 0;
    -
    -      var ds = problemData.Dataset;
    -      var targetVariable = problemData.TargetVariable;
    -
    -      var partitions = folds.ToList();
    -
    -      for (int i = 0; i < partitions.Count; ++i) {
    -        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
    -        var test = partitions[i];
    -        var training = new List<int>();
    -        for (int j = 0; j < i; ++j)
    -          training.AddRange(partitions[j]);
    -
    -        for (int j = i + 1; j < partitions.Count; ++j)
    -          training.AddRange(partitions[j]);
    -
    -        var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, m, r, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, training);
    -        var estimatedValues = model.GetEstimatedValues(ds, test);
    -        var outputValues = ds.GetDoubleValues(targetVariable, test);
    -
    -        OnlineCalculatorError calculatorError;
    -        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, outputValues, out calculatorError);
    -        if (calculatorError != OnlineCalculatorError.None)
    -          mse = double.NaN;
    -        avgTestMSE += mse;
    -      }
    -
    -      avgTestMSE /= partitions.Count;
    -    }
    -
    -    public static RFParameter GridSearch(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    -      DoubleValue mse = new DoubleValue(Double.MaxValue);
    -      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
    -
    -      var pNames = parameterRanges.Keys.ToList();
    -      var pRanges = pNames.Select(x => parameterRanges[x]);
    -      var setters = pNames.Select(GenerateSetter).ToList();
    -
    -      var crossProduct = pRanges.CartesianProduct();
    -
    -      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
    -        var list = nuple.ToList();
    -        double testMSE;
    -        var parameters = new RFParameter();
    -        for (int i = 0; i < pNames.Count; ++i) {
    -          var s = setters[i];
    -          s(parameters, list[i]);
    -        }
    -        CrossValidate(problemData, folds, parameters, seed, out testMSE);
    -        if (testMSE < mse.Value) {
    -          lock (mse) {
    -            mse.Value = testMSE;
    -          }
    -          lock (bestParameter) {
    -            bestParameter = (RFParameter)parameters.Clone();
    -          }
    -        }
    -      });
    -      return bestParameter;
    +    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
    +      var regressionProblemData = problemData as IRegressionProblemData;
    +      var classificationProblemData = problemData as IClassificationProblemData;
    +
    +      if (regressionProblemData != null)
    +        return regressionProblemData.TargetVariable;
    +      if (classificationProblemData != null)
    +        return classificationProblemData.TargetVariable;
    +
    +      throw new ArgumentException("Problem data is neither regression nor classification problem data.");
         }
       }
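
    Putting the pieces together: grid search is driven by a dictionary of parameter ranges whose keys must match the RFParameter property names ("N", "R", "M"), because GenerateSetter resolves each key via Expression.Property. A hypothetical end-to-end call, assuming a loaded IRegressionProblemData named problemData; the ranges are illustrative only:

        var parameterRanges = new Dictionary<string, IEnumerable<double>> {
          { "N", new double[] { 50, 100, 200 } },
          { "R", new double[] { 0.1, 0.3, 0.5 } },
          { "M", new double[] { 0.1, 0.3, 0.5 } }
        };
        // out-of-bag variant: no cross-validation, the OOB RMS error picks the winner
        RFParameter best = RandomForestUtil.GridSearch(problemData, parameterRanges, seed: 12345, maxDegreeOfParallelism: 4);
        // 5-fold cross-validated variant with shuffled folds (stratified for classification data)
        RFParameter bestCV = RandomForestUtil.GridSearch(problemData, 5, true, parameterRanges, seed: 12345);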