Changeset 13204


Ignore:
Timestamp:
11/17/15 11:14:16 (4 years ago)
Author:
gkronber
Message:

#2385: added CreateSolution flag to random forest

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r12504 r13204  
    4242    private const string SeedParameterName = "Seed";
    4343    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     44    private const string CreateSolutionParameterName = "CreateSolution";
    4445
    4546    #region parameter properties
     
    5859    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    5960      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     63      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    6064    }
    6165    #endregion
     
    8185      set { SetSeedRandomlyParameter.Value.Value = value; }
    8286    }
     87    public bool CreateSolution {
     88      get { return CreateSolutionParameter.Value.Value; }
     89      set { CreateSolutionParameter.Value.Value = value; }
     90    }
    8391    #endregion
    8492
     
    96104      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    97105      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     106      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     107      Parameters[CreateSolutionParameterName].Hidden = true;
     108
    98109      Problem = new ClassificationProblem();
    99110    }
     
    101112    [StorableHook(HookType.AfterDeserialization)]
    102113    private void AfterDeserialization() {
     114      // BackwardsCompatibility3.3
     115      #region Backwards compatible code, remove with 3.4
    103116      if (!Parameters.ContainsKey(MParameterName))
    104117        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     
    107120      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    108121        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     122      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     123        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     124        Parameters[CreateSolutionParameterName].Hidden = true;
     125      }
     126      #endregion
    109127    }
    110128
     
    118136      if (SetSeedRandomly) Seed = new System.Random().Next();
    119137
    120       var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    121       Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
     138      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    122139      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    123140      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
    124141      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
    125142      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
     143
     144      if (CreateSolution) {
     145        var solution = new RandomForestClassificationSolution((IClassificationProblemData)Problem.ProblemData.Clone(), model);
     146        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
     147      }
     148    }
     149   
     150    // keep for compatibility with old API
     151    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     152      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
     153      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     154      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    126155    }
    127156
    128     public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     157    public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    129158      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    130       var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    131       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
     159      return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    132160    }
    133161    #endregion
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r12504 r13204  
    4242    private const string SeedParameterName = "Seed";
    4343    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     44    private const string CreateSolutionParameterName = "CreateSolution";
    4445
    4546    #region parameter properties
     
    5859    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    5960      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     63      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    6064    }
    6165    #endregion
     
    8185      set { SetSeedRandomlyParameter.Value.Value = value; }
    8286    }
     87    public bool CreateSolution {
     88      get { return CreateSolutionParameter.Value.Value; }
     89      set { CreateSolutionParameter.Value.Value = value; }
     90    }
    8391    #endregion
    8492    [StorableConstructor]
     
    95103      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    96104      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     105      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     106      Parameters[CreateSolutionParameterName].Hidden = true;
     107
    97108      Problem = new RegressionProblem();
    98109    }
     
    100111    [StorableHook(HookType.AfterDeserialization)]
    101112    private void AfterDeserialization() {
     113      // BackwardsCompatibility3.3
     114      #region Backwards compatible code, remove with 3.4
    102115      if (!Parameters.ContainsKey(MParameterName))
    103116        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     
    106119      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    107120        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     121      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     122        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     123        Parameters[CreateSolutionParameterName].Hidden = true;
     124      }
     125      #endregion
    108126    }
    109127
     
    116134      double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
    117135      if (SetSeedRandomly) Seed = new System.Random().Next();
     136      var model = CreateRandomForestRegressionModel(Problem.ProblemData, NumberOfTrees, R, M, Seed,
     137        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    118138
    119       var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    120       Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    121139      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    122140      Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the training set.", new PercentValue(avgRelError)));
    123141      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
    124142      Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError)));
     143
     144      if (CreateSolution) {
     145        var solution = new RandomForestRegressionSolution((IRegressionProblemData)Problem.ProblemData.Clone(), model);
     146        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
     147      }
    125148    }
    126149
    127     public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
     150    // keep for compatibility with old API
     151    public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    128152      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    129       var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     153      var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed,
     154        out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    130155      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
    131156    }
     157
     158    public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
     159      double r, double m, int seed,
     160      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     161      return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     162    }
     163
    132164    #endregion
    133165  }
Note: See TracChangeset for help on using the changeset viewer.