Changeset 14029 for branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Timestamp:
- 07/08/16 14:40:02 (9 years ago)
- Location:
- branches/crossvalidation-2434
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r12504 r14029 32 32 /// Random forest classification data analysis algorithm. 33 33 /// </summary> 34 [Item("Random Forest Classification ", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")]34 [Item("Random Forest Classification (RF)", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")] 35 35 [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 120)] 36 36 [StorableClass] … … 42 42 private const string SeedParameterName = "Seed"; 43 43 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 44 private const string CreateSolutionParameterName = "CreateSolution"; 44 45 45 46 #region parameter properties … … 58 59 public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter { 59 60 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 61 } 62 public IFixedValueParameter<BoolValue> CreateSolutionParameter { 63 get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; } 60 64 } 61 65 #endregion … … 81 85 set { SetSeedRandomlyParameter.Value.Value = value; } 82 86 } 87 public bool CreateSolution { 88 get { return CreateSolutionParameter.Value.Value; } 89 set { CreateSolutionParameter.Value.Value = value; } 90 } 83 91 #endregion 84 92 … … 96 104 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 97 105 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 106 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 107 Parameters[CreateSolutionParameterName].Hidden = true; 108 98 109 Problem = new ClassificationProblem(); 99 110 } … … 101 112 
[StorableHook(HookType.AfterDeserialization)] 102 113 private void AfterDeserialization() { 114 // BackwardsCompatibility3.3 115 #region Backwards compatible code, remove with 3.4 103 116 if (!Parameters.ContainsKey(MParameterName)) 104 117 Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5))); … … 107 120 if (!Parameters.ContainsKey((SetSeedRandomlyParameterName))) 108 121 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 122 if (!Parameters.ContainsKey(CreateSolutionParameterName)) { 123 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 124 Parameters[CreateSolutionParameterName].Hidden = true; 125 } 126 #endregion 109 127 } 110 128 … … 118 136 if (SetSeedRandomly) Seed = new System.Random().Next(); 119 137 120 var solution = CreateRandomForestClassificationSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 121 Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution)); 138 var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 122 139 Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError))); 123 140 Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", 
new PercentValue(relClassificationError))); 124 141 Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError))); 125 142 Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError))); 143 144 if (CreateSolution) { 145 var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone()); 146 Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution)); 147 } 126 148 } 127 149 128 public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 150 // keep for compatibility with old API 151 public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 129 152 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 130 var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 131 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model); 153 var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 154 return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone()); 155 } 156 157 public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData 
problemData, int nTrees, double r, double m, int seed, 158 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 159 return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 132 160 } 133 161 #endregion -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs
r12012 r14029 43 43 : base(original, cloner) { 44 44 } 45 public RandomForestClassificationSolution(IClassificationProblemData problemData, IRandomForestModel randomForestModel) 45 public RandomForestClassificationSolution(IRandomForestModel randomForestModel, IClassificationProblemData problemData) 46 46 : base(randomForestModel, problemData) { 47 47 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r12509 r14029 34 34 [StorableClass] 35 35 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 36 public sealed class RandomForestModel : NamedItem, IRandomForestModel {36 public sealed class RandomForestModel : ClassificationModel, IRandomForestModel { 37 37 // not persisted 38 38 private alglib.decisionforest randomForest; … … 44 44 } 45 45 } 46 47 public override IEnumerable<string> VariablesUsedForPrediction { 48 get { return originalTrainingData.AllowedInputVariables; } 49 } 50 46 51 47 52 // instead of storing the data of the model itself … … 91 96 92 97 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 93 private RandomForestModel( alglib.decisionforest randomForest,98 private RandomForestModel(string targetVariable, alglib.decisionforest randomForest, 94 99 int seed, IDataAnalysisProblemData originalTrainingData, 95 100 int nTrees, double r, double m, double[] classValues = null) 96 : base( ) {101 : base(targetVariable) { 97 102 this.name = ItemName; 98 103 this.description = ItemDescription; … … 147 152 } 148 153 149 public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {154 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 150 155 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 151 156 AssertInputMatrix(inputData); … … 174 179 } 175 180 176 public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 177 return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this); 178 } 179 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 180 return CreateRegressionSolution(problemData); 181 } 182 public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData 
problemData) { 183 return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this); 184 } 185 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 186 return CreateClassificationSolution(problemData); 181 182 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 183 return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData)); 184 } 185 public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 186 return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData)); 187 187 } 188 188 … … 205 205 outOfBagRmsError = rep.oobrmserror; 206 206 207 return new RandomForestModel( dForest, seed, problemData, nTrees, r, m);207 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m); 208 208 } 209 209 … … 242 242 outOfBagRelClassificationError = rep.oobrelclserror; 243 243 244 return new RandomForestModel( dForest, seed, problemData, nTrees, r, m, classValues);244 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues); 245 245 } 246 246 -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r12504 r14029 32 32 /// Random forest regression data analysis algorithm. 33 33 /// </summary> 34 [Item("Random Forest Regression ", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")]34 [Item("Random Forest Regression (RF)", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")] 35 35 [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)] 36 36 [StorableClass] … … 42 42 private const string SeedParameterName = "Seed"; 43 43 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 44 private const string CreateSolutionParameterName = "CreateSolution"; 44 45 45 46 #region parameter properties … … 58 59 public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter { 59 60 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 61 } 62 public IFixedValueParameter<BoolValue> CreateSolutionParameter { 63 get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; } 60 64 } 61 65 #endregion … … 81 85 set { SetSeedRandomlyParameter.Value.Value = value; } 82 86 } 87 public bool CreateSolution { 88 get { return CreateSolutionParameter.Value.Value; } 89 set { CreateSolutionParameter.Value.Value = value; } 90 } 83 91 #endregion 84 92 [StorableConstructor] … … 95 103 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 96 104 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 105 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 106 Parameters[CreateSolutionParameterName].Hidden = true; 107 97 108 Problem = new RegressionProblem(); 98 109 } … … 100 111 
[StorableHook(HookType.AfterDeserialization)] 101 112 private void AfterDeserialization() { 113 // BackwardsCompatibility3.3 114 #region Backwards compatible code, remove with 3.4 102 115 if (!Parameters.ContainsKey(MParameterName)) 103 116 Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.5))); … … 106 119 if (!Parameters.ContainsKey((SetSeedRandomlyParameterName))) 107 120 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 121 if (!Parameters.ContainsKey(CreateSolutionParameterName)) { 122 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 123 Parameters[CreateSolutionParameterName].Hidden = true; 124 } 125 #endregion 108 126 } 109 127 … … 116 134 double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError; 117 135 if (SetSeedRandomly) Seed = new System.Random().Next(); 136 var model = CreateRandomForestRegressionModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, 137 out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 118 138 119 var solution = CreateRandomForestRegressionSolution(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);120 Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));121 139 Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError))); 122 140 Results.Add(new Result("Average relative error", "The average of relative errors of the random forest regression solution on the 
training set.", new PercentValue(avgRelError))); 123 141 Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest regression solution.", new DoubleValue(outOfBagRmsError))); 124 142 Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError))); 143 144 if (CreateSolution) { 145 var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone()); 146 Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution)); 147 } 125 148 } 126 149 127 public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 150 // keep for compatibility with old API 151 public static RandomForestRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 128 152 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 129 var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 130 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model); 153 var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed, 154 out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 155 return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone()); 131 156 } 157 158 public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees, 159 double r, double m, int seed, 160 out double rmsError, out double avgRelError, out double outOfBagRmsError, out 
double outOfBagAvgRelError) { 161 return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 162 } 163 132 164 #endregion 133 165 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs
r12012 r14029 43 43 : base(original, cloner) { 44 44 } 45 public RandomForestRegressionSolution(IRegressionProblemData problemData, IRandomForestModel randomForestModel) 45 public RandomForestRegressionSolution(IRandomForestModel randomForestModel, IRegressionProblemData problemData) 46 46 : base(randomForestModel, problemData) { 47 47 }
Note: See TracChangeset for help on using the changeset viewer.