Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/28/19 13:58:06 (5 years ago)
Author:
mkommend
Message:

#2952: Finished implemenation of different RF models.

Location:
branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis

  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4

  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r17045 r17050  
    2020#endregion
    2121
     22using System.Collections.Generic;
     23using System.Linq;
    2224using System.Threading;
    2325using HEAL.Attic;
     26using HeuristicLab.Algorithms.DataAnalysis.RandomForest;
    2427using HeuristicLab.Common;
    2528using HeuristicLab.Core;
     
    4346    private const string SeedParameterName = "Seed";
    4447    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    45     private const string CreateSolutionParameterName = "CreateSolution";
     48    private const string ModelCreationParameterName = "ModelCreation";
    4649
    4750    #region parameter properties
     
    6164      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    6265    }
    63     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    64       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     66    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     67      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    6568    }
    6669    #endregion
     
    8689      set { SetSeedRandomlyParameter.Value.Value = value; }
    8790    }
    88     public bool CreateSolution {
    89       get { return CreateSolutionParameter.Value.Value; }
    90       set { CreateSolutionParameter.Value.Value = value; }
     91    public ModelCreation ModelCreation {
     92      get { return ModelCreationParameter.Value.Value; }
     93      set { ModelCreationParameter.Value.Value = value; }
    9194    }
    9295    #endregion
     
    105108      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    106109      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    107       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    108       Parameters[CreateSolutionParameterName].Hidden = true;
     110      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     111      Parameters[ModelCreationParameterName].Hidden = true;
    109112
    110113      Problem = new ClassificationProblem();
     
    121124      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    122125        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    123       if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
    124         Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    125         Parameters[CreateSolutionParameterName].Hidden = true;
     126
     127      // parameter type has been changed
     128      if (Parameters.ContainsKey("CreateSolution")) {
     129        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     130        Parameters.Remove(createSolutionParam);
     131
     132        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     133        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     134        Parameters[ModelCreationParameterName].Hidden = true;
    126135      }
    127136      #endregion
     
    138147
    139148      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     149
    140150      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    141151      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
     
    143153      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
    144154
    145       if (CreateSolution) {
    146         var solution = model.CreateClassificationSolution(Problem.ProblemData);
     155
     156      IClassificationSolution solution = null;
     157      if (ModelCreation == ModelCreation.Model) {
     158        solution = model.CreateClassificationSolution(Problem.ProblemData);
     159      } else if (ModelCreation == ModelCreation.SurrogateModel) {
     160        var problemData = Problem.ProblemData;
     161        var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray());
     162
     163        solution = surrogateModel.CreateClassificationSolution(problemData);
     164      }
     165
     166      if (solution != null) {
    147167        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
    148168      }
     
    157177    }
    158178
    159     public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     179    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     180 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     181      var model = CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     182      return model;
     183    }
     184
     185    public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
    160186      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    161       return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed,
    162        rmsError: out rmsError, relClassificationError: out relClassificationError, outOfBagRmsError: out outOfBagRmsError, outOfBagRelClassificationError: out outOfBagRelClassificationError);
     187
     188      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     189      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
     190
     191      var classValues = problemData.ClassValues.ToArray();
     192      int nClasses = classValues.Length;
     193
     194      // map original class values to values [0..nClasses-1]
     195      var classIndices = new Dictionary<double, double>();
     196      for (int i = 0; i < nClasses; i++) {
     197        classIndices[classValues[i]] = i;
     198      }
     199
     200      int nRows = inputMatrix.GetLength(0);
     201      int nColumns = inputMatrix.GetLength(1);
     202      for (int row = 0; row < nRows; row++) {
     203        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     204      }
     205
     206      alglib.dfreport rep;
     207      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     208
     209      rmsError = rep.rmserror;
     210      outOfBagRmsError = rep.oobrmserror;
     211      relClassificationError = rep.relclserror;
     212      outOfBagRelClassificationError = rep.oobrelclserror;
     213
     214      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables, classValues);
    163215    }
    164216    #endregion
Note: See TracChangeset for help on using the changeset viewer.