Changeset 13204 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
- Timestamp:
- 11/17/15 11:14:16 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r12504 r13204 42 42 private const string SeedParameterName = "Seed"; 43 43 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 44 private const string CreateSolutionParameterName = "CreateSolution"; 44 45 45 46 #region parameter properties … … 58 59 public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter { 59 60 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 61 } 62 public IFixedValueParameter<BoolValue> CreateSolutionParameter { 63 get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; } 60 64 } 61 65 #endregion … … 81 85 set { SetSeedRandomlyParameter.Value.Value = value; } 82 86 } 87 public bool CreateSolution { 88 get { return CreateSolutionParameter.Value.Value; } 89 set { CreateSolutionParameter.Value.Value = value; } 90 } 83 91 #endregion 84 92 … … 96 104 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 97 105 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 106 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 107 Parameters[CreateSolutionParameterName].Hidden = true; 108 98 109 Problem = new ClassificationProblem(); 99 110 } … … 101 112 [StorableHook(HookType.AfterDeserialization)] 102 113 private void AfterDeserialization() { 114 // BackwardsCompatibility3.3 115 #region Backwards compatible code, remove with 3.4 103 116 if (!Parameters.ContainsKey(MParameterName)) 104 117 Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5))); … … 107 120 if (!Parameters.ContainsKey((SetSeedRandomlyParameterName))) 108 121 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 122 if (!Parameters.ContainsKey(CreateSolutionParameterName)) { 123 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 124 Parameters[CreateSolutionParameterName].Hidden = true; 125 } 126 #endregion 109 127 } 110 128 … … 118 136 if (SetSeedRandomly) Seed = new System.Random().Next(); 119 137 120 var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 121 Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution)); 138 var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 122 139 Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError))); 123 140 Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError))); 124 141 Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError))); 125 142 Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError))); 143 144 if (CreateSolution) { 145 var solution = new RandomForestClassificationSolution((IClassificationProblemData)Problem.ProblemData.Clone(), model); 146 Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution)); 147 } 148 } 149 150 // keep for compatibility with old API 151 public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 152 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 153 var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 154 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model); 126 155 } 127 156 128 public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,157 public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 129 158 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 130 var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 131 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model); 159 return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 132 160 } 133 161 #endregion
Note: See TracChangeset
for help on using the changeset viewer.