Free cookie consent management tool by TermsFeed Policy Generator

Changeset 17154


Ignore:
Timestamp:
07/23/19 20:07:08 (5 years ago)
Author:
gkronber
Message:

#2952: merged relevant revisions from branch to trunk

Merged revision(s) 17045-17153 from branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis:
#2952: Intermediate commit of refactoring RF models that is not yet finished.

........
#2952: Corrected evaluation in RF models.

........
#2952: Finished implementation of different RF models.

........
#2952 Fixed triggering model recalculation when cloning.
........
#2952: merged r17137 from trunk to branch
........
#2952: re-added backwards compatibility code for very old versions of GBT and RF
........
#2952: hide parameter in backwards compatibility hook
........

17045-17153

Location:
trunk/HeuristicLab.Algorithms.DataAnalysis
Files:
3 added
8 edited

Legend:

Unmodified
Added
Removed
  • trunk/HeuristicLab.Algorithms.DataAnalysis

  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4

  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r17044 r17154  
    196196        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
    197197        Parameters[ModelCreationParameterName].Hidden = true;
     198      } else if (!Parameters.ContainsKey(ModelCreationParameterName)) {
     199        // very old version contains neither ModelCreationParameter nor CreateSolutionParameter
     200        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     201        Parameters[ModelCreationParameterName].Hidden = true;
    198202      }
    199203      #endregion
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r17030 r17154  
    299299    <Compile Include="Plugin.cs" />
    300300    <Compile Include="Properties\AssemblyInfo.cs" />
     301    <Compile Include="RandomForest\ModelCreation.cs" />
    301302    <Compile Include="RandomForest\RandomForestClassificationSolution.cs" />
    302303    <Compile Include="RandomForest\RandomForestClassification.cs" />
    303304    <Compile Include="RandomForest\RandomForestModel.cs" />
     305    <Compile Include="RandomForest\RandomForestModelFull.cs" />
    304306    <Compile Include="RandomForest\RandomForestRegression.cs" />
    305307    <Compile Include="RandomForest\RandomForestRegressionSolution.cs" />
     308    <Compile Include="RandomForest\RandomForestModelSurrogate.cs" />
    306309    <Compile Include="RandomForest\RandomForestUtil.cs" />
    307310    <Compile Include="SupportVectorMachine\SupportVectorClassification.cs" />
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r16565 r17154  
    2020#endregion
    2121
     22using System.Collections.Generic;
     23using System.Linq;
    2224using System.Threading;
     25using HEAL.Attic;
     26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
    2327using HeuristicLab.Common;
    2428using HeuristicLab.Core;
     
    2630using HeuristicLab.Optimization;
    2731using HeuristicLab.Parameters;
    28 using HEAL.Attic;
    2932using HeuristicLab.Problems.DataAnalysis;
    3033
     
    4346    private const string SeedParameterName = "Seed";
    4447    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    45     private const string CreateSolutionParameterName = "CreateSolution";
     48    private const string ModelCreationParameterName = "ModelCreation";
    4649
    4750    #region parameter properties
     
    6164      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    6265    }
    63     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    64       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    6568    }
    6669    #endregion
     
    8689      set { SetSeedRandomlyParameter.Value.Value = value; }
    8790    }
    88     public bool CreateSolution {
    89       get { return CreateSolutionParameter.Value.Value; }
    90       set { CreateSolutionParameter.Value.Value = value; }
     91    public ModelCreation ModelCreation {
     92      get { return ModelCreationParameter.Value.Value; }
     93      set { ModelCreationParameter.Value.Value = value; }
    9194    }
    9295    #endregion
     
    105108      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    106109      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    107       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    108       Parameters[CreateSolutionParameterName].Hidden = true;
     110      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     111      Parameters[ModelCreationParameterName].Hidden = true;
    109112
    110113      Problem = new ClassificationProblem();
     
    121124      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    122125        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    123       if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
    124         Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    125         Parameters[CreateSolutionParameterName].Hidden = true;
     126
     127      // parameter type has been changed
     128      if (Parameters.ContainsKey("CreateSolution")) {
     129        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     130        Parameters.Remove(createSolutionParam);
     131
     132        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     133        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     134        Parameters[ModelCreationParameterName].Hidden = true;
     135      } else if (!Parameters.ContainsKey(ModelCreationParameterName)) {
     136        // very old version contains neither ModelCreationParameter nor CreateSolutionParameter
     137        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     138        Parameters[ModelCreationParameterName].Hidden = true;
    126139      }
    127140      #endregion
     
    138151
    139152      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     153
    140154      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    141155      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
     
    143157      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
    144158
    145       if (CreateSolution) {
    146         var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone());
     159
     160      IClassificationSolution solution = null;
     161      if (ModelCreation == ModelCreation.Model) {
     162        solution = model.CreateClassificationSolution(Problem.ProblemData);
     163      } else if (ModelCreation == ModelCreation.SurrogateModel) {
     164        var problemData = Problem.ProblemData;
     165        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray());
     166
     167        solution = surrogateModel.CreateClassificationSolution(problemData);
     168      }
     169
     170      if (solution != null) {
    147171        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
    148172      }
     
    157181    }
    158182
    159     public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     183    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     184 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     185      var model = CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     186      return model;
     187    }
     188
     189    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
    160190      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    161       return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed,
    162        rmsError: out rmsError, relClassificationError: out relClassificationError, outOfBagRmsError: out outOfBagRmsError, outOfBagRelClassificationError: out outOfBagRelClassificationError);
     191
     192      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     193      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
     194
     195      var classValues = problemData.ClassValues.ToArray();
     196      int nClasses = classValues.Length;
     197
     198      // map original class values to values [0..nClasses-1]
     199      var classIndices = new Dictionary<double, double>();
     200      for (int i = 0; i < nClasses; i++) {
     201        classIndices[classValues[i]] = i;
     202      }
     203
     204      int nRows = inputMatrix.GetLength(0);
     205      int nColumns = inputMatrix.GetLength(1);
     206      for (int row = 0; row < nRows; row++) {
     207        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     208      }
     209
     210      alglib.dfreport rep;
     211      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     212
     213      rmsError = rep.rmserror;
     214      outOfBagRmsError = rep.oobrmserror;
     215      relClassificationError = rep.relclserror;
     216      outOfBagRelClassificationError = rep.oobrelclserror;
     217
     218      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables, classValues);
    163219    }
    164220    #endregion
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r16763 r17154  
    2323using System.Collections.Generic;
    2424using System.Linq;
     25using HEAL.Attic;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
    2728using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    28 using HEAL.Attic;
    2929using HeuristicLab.Problems.DataAnalysis;
    3030using HeuristicLab.Problems.DataAnalysis.Symbolic;
     
    3434  /// Represents a random forest model for regression and classification
    3535  /// </summary>
    36   [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")]
     36  [Obsolete("This class only exists for backwards compatibility reasons for stored models with the XML Persistence. Use RFModelSurrogate or RFModelFull instead.")]
     37  [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")]
    3738  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3839  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
     
    139140    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    140141      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    141       AssertInputMatrix(inputData);
     142      RandomForestUtil.AssertInputMatrix(inputData);
    142143
    143144      int n = inputData.GetLength(0);
     
    157158    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    158159      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    159       AssertInputMatrix(inputData);
     160      RandomForestUtil.AssertInputMatrix(inputData);
    160161
    161162      int n = inputData.GetLength(0);
     
    175176    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    176177      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    177       AssertInputMatrix(inputData);
     178      RandomForestUtil.AssertInputMatrix(inputData);
    178179
    179180      int n = inputData.GetLength(0);
     
    315316
    316317      alglib.dfreport rep;
    317       var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     318      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
    318319
    319320      rmsError = rep.rmserror;
     
    353354
    354355      alglib.dfreport rep;
    355       var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     356      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
    356357
    357358      rmsError = rep.rmserror;
     
    361362
    362363      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    363     }
    364 
    365     private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
    366       AssertParameters(r, m);
    367       AssertInputMatrix(inputMatrix);
    368 
    369       int info = 0;
    370       alglib.math.rndobject = new System.Random(seed);
    371       var dForest = new alglib.decisionforest();
    372       rep = new alglib.dfreport();
    373       int nRows = inputMatrix.GetLength(0);
    374       int nColumns = inputMatrix.GetLength(1);
    375       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    376       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    377 
    378       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    379       if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
    380       return dForest;
    381     }
    382 
    383     private static void AssertParameters(double r, double m) {
    384       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
    385       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    386     }
    387 
    388     private static void AssertInputMatrix(double[,] inputMatrix) {
    389       if (inputMatrix.ContainsNanOrInfinity())
    390         throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    391364    }
    392365
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r16565 r17154  
    2020#endregion
    2121
     22using System.Collections.Generic;
     23using System.Linq;
    2224using System.Threading;
     25using HEAL.Attic;
     26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
    2327using HeuristicLab.Common;
    2428using HeuristicLab.Core;
     
    2630using HeuristicLab.Optimization;
    2731using HeuristicLab.Parameters;
    28 using HEAL.Attic;
    2932using HeuristicLab.Problems.DataAnalysis;
    3033
     
    4346    private const string SeedParameterName = "Seed";
    4447    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    45     private const string CreateSolutionParameterName = "CreateSolution";
     48    private const string ModelCreationParameterName = "ModelCreation";
    4649
    4750    #region parameter properties
     
    6164      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    6265    }
    63     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    64       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    6568    }
    6669    #endregion
     
    8689      set { SetSeedRandomlyParameter.Value.Value = value; }
    8790    }
    88     public bool CreateSolution {
    89       get { return CreateSolutionParameter.Value.Value; }
    90       set { CreateSolutionParameter.Value.Value = value; }
     91    public ModelCreation ModelCreation {
     92      get { return ModelCreationParameter.Value.Value; }
     93      set { ModelCreationParameter.Value.Value = value; }
    9194    }
    9295    #endregion
     
    104107      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    105108      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    106       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    107       Parameters[CreateSolutionParameterName].Hidden = true;
     109      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     110      Parameters[ModelCreationParameterName].Hidden = true;
    108111
    109112      Problem = new RegressionProblem();
     
    120123      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    121124        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    122       if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
    123         Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    124         Parameters[CreateSolutionParameterName].Hidden = true;
     125
     126      // parameter type has been changed
     127      if (Parameters.ContainsKey("CreateSolution")) {
     128        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     129        Parameters.Remove(createSolutionParam);
     130
     131        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     132        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     133        Parameters[ModelCreationParameterName].Hidden = true;
     134      } else if (!Parameters.ContainsKey(ModelCreationParameterName)) {
     135        // very old version contains neither ModelCreationParameter nor CreateSolutionParameter
     136        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     137        Parameters[ModelCreationParameterName].Hidden = true;
    125138      }
    126139      #endregion
     
    143156      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
    144157
    145       if (CreateSolution) {
    146         var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone());
     158      IRegressionSolution solution = null;
     159      if (ModelCreation == ModelCreation.Model) {
     160        solution = model.CreateRegressionSolution(Problem.ProblemData);
     161      } else if (ModelCreation == ModelCreation.SurrogateModel) {
     162        var problemData = Problem.ProblemData;
     163        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M);
     164        solution = surrogateModel.CreateRegressionSolution(problemData);
     165      }
     166
     167      if (solution != null) {
    147168        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    148169      }
    149170    }
     171
    150172
    151173    // keep for compatibility with old API
     
    157179    }
    158180
    159     public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
    160       double r, double m, int seed,
    161       out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    162       return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed,
    163         rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError);
     181    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
     182     double r, double m, int seed,
     183     out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     184      var model = CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     185      return model;
     186    }
     187
     188    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     189    out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     190
     191      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     192      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
     193
     194      alglib.dfreport rep;
     195      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     196
     197      rmsError = rep.rmserror;
     198      outOfBagRmsError = rep.oobrmserror;
     199      avgRelError = rep.avgrelerror;
     200      outOfBagAvgRelError = rep.oobavgrelerror;
     201
     202      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables);
    164203    }
    165204
  • trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r16565 r17154  
    2727using System.Linq.Expressions;
    2828using System.Threading.Tasks;
     29using HEAL.Attic;
    2930using HeuristicLab.Common;
    3031using HeuristicLab.Core;
    3132using HeuristicLab.Data;
    3233using HeuristicLab.Parameters;
    33 using HEAL.Attic;
    3434using HeuristicLab.Problems.DataAnalysis;
    3535using HeuristicLab.Random;
     
    8989
    9090  public static class RandomForestUtil {
     91    public static void AssertParameters(double r, double m) {
     92      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     93      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     94    }
     95
     96    public static void AssertInputMatrix(double[,] inputMatrix) {
     97      if (inputMatrix.ContainsNanOrInfinity())
     98        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     99    }
     100
     101    internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     102      RandomForestUtil.AssertParameters(r, m);
     103      RandomForestUtil.AssertInputMatrix(inputMatrix);
     104
     105      int info = 0;
     106      alglib.math.rndobject = new System.Random(seed);
     107      var dForest = new alglib.decisionforest();
     108      rep = new alglib.dfreport();
     109      int nRows = inputMatrix.GetLength(0);
     110      int nColumns = inputMatrix.GetLength(1);
     111      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     112      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     113
     114      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     115      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     116      return dForest;
     117    }
     118
     119
    91120    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    92121      avgTestMse = 0;
Note: See TracChangeset for help on using the changeset viewer.