Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/28/19 13:58:06 (6 years ago)
Author:
mkommend
Message:

#2952: Finished implemenation of different RF models.

Location:
branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis
Files:
7 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis

  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4

  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/ModelCreation.cs

    r17049 r17050  
    2323using HEAL.Attic;
    2424
    25 namespace HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees {
     25namespace HeuristicLab.Algorithms.DataAnalysis.RandomForest {
    2626
    2727  /// <summary>
     
    3131  /// Model - the complete model will be stored (consider the amount of memory needed)
    3232  /// </summary>
    33   [StorableType("EE55C357-C4B3-4662-B40B-D1D06A851809")]
     33  [StorableType("3869899B-1848-4628-AF27-3B6FDE5840D6")]
    3434  public enum ModelCreation {
    3535    QualityOnly,
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r17045 r17050  
    2020#endregion
    2121
     22using System.Collections.Generic;
     23using System.Linq;
    2224using System.Threading;
    2325using HEAL.Attic;
     26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
    2427using HeuristicLab.Common;
    2528using HeuristicLab.Core;
     
    4346    private const string SeedParameterName = "Seed";
    4447    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    45     private const string CreateSolutionParameterName = "CreateSolution";
     48    private const string ModelCreationParameterName = "ModelCreation";
    4649
    4750    #region parameter properties
     
    6164      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    6265    }
    63     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    64       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    6568    }
    6669    #endregion
     
    8689      set { SetSeedRandomlyParameter.Value.Value = value; }
    8790    }
    88     public bool CreateSolution {
    89       get { return CreateSolutionParameter.Value.Value; }
    90       set { CreateSolutionParameter.Value.Value = value; }
     91    public ModelCreation ModelCreation {
     92      get { return ModelCreationParameter.Value.Value; }
     93      set { ModelCreationParameter.Value.Value = value; }
    9194    }
    9295    #endregion
     
    105108      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    106109      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    107       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    108       Parameters[CreateSolutionParameterName].Hidden = true;
     110      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     111      Parameters[ModelCreationParameterName].Hidden = true;
    109112
    110113      Problem = new ClassificationProblem();
     
    121124      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    122125        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    123       if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
    124         Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    125         Parameters[CreateSolutionParameterName].Hidden = true;
     126
     127      // parameter type has been changed
     128      if (Parameters.ContainsKey("CreateSolution")) {
     129        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     130        Parameters.Remove(createSolutionParam);
     131
     132        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     133        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     134        Parameters[ModelCreationParameterName].Hidden = true;
    126135      }
    127136      #endregion
     
    138147
    139148      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     149
    140150      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    141151      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
     
    143153      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
    144154
    145       if (CreateSolution) {
    146         var solution = model.CreateClassificationSolution(Problem.ProblemData);
     155
     156      IClassificationSolution solution = null;
     157      if (ModelCreation == ModelCreation.Model) {
     158        solution = model.CreateClassificationSolution(Problem.ProblemData);
     159      } else if (ModelCreation == ModelCreation.SurrogateModel) {
     160        var problemData = Problem.ProblemData;
     161        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray());
     162
     163        solution = surrogateModel.CreateClassificationSolution(problemData);
     164      }
     165
     166      if (solution != null) {
    147167        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
    148168      }
     
    157177    }
    158178
    159     public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     179    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     180 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     181      var model = CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     182      return model;
     183    }
     184
     185    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
    160186      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    161       return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed,
    162        rmsError: out rmsError, relClassificationError: out relClassificationError, outOfBagRmsError: out outOfBagRmsError, outOfBagRelClassificationError: out outOfBagRelClassificationError);
     187
     188      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     189      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
     190
     191      var classValues = problemData.ClassValues.ToArray();
     192      int nClasses = classValues.Length;
     193
     194      // map original class values to values [0..nClasses-1]
     195      var classIndices = new Dictionary<double, double>();
     196      for (int i = 0; i < nClasses; i++) {
     197        classIndices[classValues[i]] = i;
     198      }
     199
     200      int nRows = inputMatrix.GetLength(0);
     201      int nColumns = inputMatrix.GetLength(1);
     202      for (int row = 0; row < nRows; row++) {
     203        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     204      }
     205
     206      alglib.dfreport rep;
     207      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     208
     209      rmsError = rep.rmserror;
     210      outOfBagRmsError = rep.oobrmserror;
     211      relClassificationError = rep.relclserror;
     212      outOfBagRelClassificationError = rep.oobrelclserror;
     213
     214      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables, classValues);
    163215    }
    164216    #endregion
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r17045 r17050  
    3434  /// Represents a random forest model for regression and classification
    3535  /// </summary>
    36   [Obsolete("This class only exists for backwards compatibility reasons. Use RFModelSurrogate or RFModelFull instead.")]
     36  [Obsolete("This class only exists for backwards compatibility reasons for stored models with the XML Persistence. Use RFModelSurrogate or RFModelFull instead.")]
    3737  [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")]
    3838  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelFull.cs

    r17049 r17050  
    100100
    101101    public RandomForestModelFull(alglib.decisionforest decisionForest, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<double> classValues = null) : base(targetVariable) {
     102      this.name = ItemName;
     103      this.description = ItemDescription;
     104
    102105      randomForest = decisionForest;
    103106
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelSurrogate.cs

    r17049 r17050  
    115115      IRandomForestModel randomForestModel = null;
    116116
    117       //TODO Refactor to new methods
    118117      double rmsError, oobRmsError, relClassError, oobRelClassError;
    119118      var classificationProblemData = originalTrainingData as IClassificationProblemData;
    120119
    121120      if (originalTrainingData is IRegressionProblemData regressionProblemData) {
    122         randomForestModel = RandomForestModel.CreateRegressionModel(regressionProblemData,
     121        randomForestModel = RandomForestRegression.CreateRandomForestRegressionModel(regressionProblemData,
    123122                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
    124123                                              out relClassError, out oobRelClassError);
    125124      } else if (classificationProblemData != null) {
    126         randomForestModel = RandomForestModel.CreateClassificationModel(classificationProblemData,
     125        randomForestModel = RandomForestClassification.CreateRandomForestClassificationModel(classificationProblemData,
    127126                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
    128127                                              out relClassError, out oobRelClassError);
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r17045 r17050  
    2424using System.Threading;
    2525using HEAL.Attic;
     26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
    2627using HeuristicLab.Common;
    2728using HeuristicLab.Core;
     
    4546    private const string SeedParameterName = "Seed";
    4647    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    47     private const string CreateSolutionParameterName = "CreateSolution";
     48    private const string ModelCreationParameterName = "ModelCreation";
    4849
    4950    #region parameter properties
     
    6364      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    6465    }
    65     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    66       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    6768    }
    6869    #endregion
     
    8889      set { SetSeedRandomlyParameter.Value.Value = value; }
    8990    }
    90     public bool CreateSolution {
    91       get { return CreateSolutionParameter.Value.Value; }
    92       set { CreateSolutionParameter.Value.Value = value; }
     91    public ModelCreation ModelCreation {
     92      get { return ModelCreationParameter.Value.Value; }
     93      set { ModelCreationParameter.Value.Value = value; }
    9394    }
    9495    #endregion
     
    106107      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    107108      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    108       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    109       Parameters[CreateSolutionParameterName].Hidden = true;
     109      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     110      Parameters[ModelCreationParameterName].Hidden = true;
    110111
    111112      Problem = new RegressionProblem();
     
    122123      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    123124        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    124       if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
    125         Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    126         Parameters[CreateSolutionParameterName].Hidden = true;
     125
     126      // parameter type has been changed
     127      if (Parameters.ContainsKey("CreateSolution")) {
     128        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     129        Parameters.Remove(createSolutionParam);
     130
     131        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     132        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     133        Parameters[ModelCreationParameterName].Hidden = true;
    127134      }
    128135      #endregion
     
    145152      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
    146153
    147       if (CreateSolution) {
    148         var solution = model.CreateRegressionSolution(Problem.ProblemData);
     154      IRegressionSolution solution = null;
     155      if (ModelCreation == ModelCreation.Model) {
     156        solution = model.CreateRegressionSolution(Problem.ProblemData);
     157      } else if (ModelCreation == ModelCreation.SurrogateModel) {
     158        var problemData = Problem.ProblemData;
     159        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M);
     160        solution = surrogateModel.CreateRegressionSolution(problemData);
     161      }
     162
     163      if (solution != null) {
    149164        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    150165      }
     
    163178     double r, double m, int seed,
    164179     out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    165       return CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     180      var model = CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     181      return model;
    166182    }
    167183
     
    181197
    182198      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables);
    183 
    184       //return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed,
    185       //rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError);
    186199    }
    187200
Note: See TracChangeset for help on using the changeset viewer.