Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/17/15 11:14:16 (9 years ago)
Author:
gkronber
Message:

#2385: added CreateSolution flag to random forest

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r12504 r13204  
    4242    private const string SeedParameterName = "Seed";
    4343    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     44    private const string CreateSolutionParameterName = "CreateSolution";
    4445
    4546    #region parameter properties
     
    5859    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    5960      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     63      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    6064    }
    6165    #endregion
     
    8185      set { SetSeedRandomlyParameter.Value.Value = value; }
    8286    }
     87    public bool CreateSolution {
     88      get { return CreateSolutionParameter.Value.Value; }
     89      set { CreateSolutionParameter.Value.Value = value; }
     90    }
    8391    #endregion
    8492
     
    96104      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    97105      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     106      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     107      Parameters[CreateSolutionParameterName].Hidden = true;
     108
    98109      Problem = new ClassificationProblem();
    99110    }
     
    101112    [StorableHook(HookType.AfterDeserialization)]
    102113    private void AfterDeserialization() {
     114      // BackwardsCompatibility3.3
     115      #region Backwards compatible code, remove with 3.4
    103116      if (!Parameters.ContainsKey(MParameterName))
    104117        Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5)));
     
    107120      if (!Parameters.ContainsKey((SetSeedRandomlyParameterName)))
    108121        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
     122      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     123        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     124        Parameters[CreateSolutionParameterName].Hidden = true;
     125      }
     126      #endregion
    109127    }
    110128
     
    118136      if (SetSeedRandomly) Seed = new System.Random().Next();
    119137
    120       var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    121       Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
     138      var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    122139      Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError)));
    123140      Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError)));
    124141      Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError)));
    125142      Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error  of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError)));
     143
     144      if (CreateSolution) {
     145        var solution = new RandomForestClassificationSolution((IClassificationProblemData)Problem.ProblemData.Clone(), model);
     146        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
     147      }
     148    }
     149   
     150    // keep for compatibility with old API
     151    public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     152      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
     153      var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     154      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    126155    }
    127156
    128     public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     157    public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    129158      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    130       var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    131       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
     159      return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    132160    }
    133161    #endregion
Note: See TracChangeset for help on using the changeset viewer.