- Timestamp:
- 06/28/16 13:33:17 (8 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r13238 r13941 143 143 144 144 if (CreateSolution) { 145 var solution = new RandomForestClassificationSolution( (IClassificationProblemData)Problem.ProblemData.Clone(), model);145 var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone()); 146 146 Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution)); 147 147 } 148 148 } 149 149 150 150 // keep for compatibility with old API 151 151 public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 152 152 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 153 153 var model = CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 154 return new RandomForestClassificationSolution( (IClassificationProblemData)problemData.Clone(), model);154 return new RandomForestClassificationSolution(model, (IClassificationProblemData)problemData.Clone()); 155 155 } 156 156 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs
r12012 r13941 43 43 : base(original, cloner) { 44 44 } 45 public RandomForestClassificationSolution(I ClassificationProblemData problemData, IRandomForestModel randomForestModel)45 public RandomForestClassificationSolution(IRandomForestModel randomForestModel, IClassificationProblemData problemData) 46 46 : base(randomForestModel, problemData) { 47 47 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r13921 r13941 34 34 [StorableClass] 35 35 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 36 public sealed class RandomForestModel : NamedItem, IRandomForestModel {36 public sealed class RandomForestModel : ClassificationModel, IRandomForestModel { 37 37 // not persisted 38 38 private alglib.decisionforest randomForest; … … 45 45 } 46 46 47 public IEnumerable<string> VariablesUsedForPrediction {47 public override IEnumerable<string> VariablesUsedForPrediction { 48 48 get { return originalTrainingData.AllowedInputVariables; } 49 49 } 50 50 51 public string TargetVariable {52 get {53 var regressionProblemData = originalTrainingData as IRegressionProblemData;54 var classificationProblemData = originalTrainingData as IClassificationProblemData;55 if (classificationProblemData != null)56 return classificationProblemData.TargetVariable;57 if (regressionProblemData != null)58 return regressionProblemData.TargetVariable;59 throw new InvalidOperationException("Getting the target variable requires either a regression or a classification problem data.");60 }61 }62 51 63 52 // instead of storing the data of the model itself … … 107 96 108 97 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 109 private RandomForestModel( alglib.decisionforest randomForest,98 private RandomForestModel(string targetVariable, alglib.decisionforest randomForest, 110 99 int seed, IDataAnalysisProblemData originalTrainingData, 111 100 int nTrees, double r, double m, double[] classValues = null) 112 : base( ) {101 : base(targetVariable) { 113 102 this.name = ItemName; 114 103 this.description = ItemDescription; … … 163 152 } 164 153 165 public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {154 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 166 155 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 167 156 AssertInputMatrix(inputData); … … 190 179 } 191 180 192 public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 193 return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this); 194 } 195 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 196 return CreateRegressionSolution(problemData); 197 } 198 public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 199 return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this); 200 } 201 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 202 return CreateClassificationSolution(problemData); 181 182 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 183 return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData)); 184 } 185 public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 186 return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData)); 203 187 } 204 188 … … 221 205 outOfBagRmsError = rep.oobrmserror; 222 206 223 return new RandomForestModel( dForest, seed, problemData, nTrees, r, m);207 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m); 224 208 } 225 209 … … 258 242 outOfBagRelClassificationError = rep.oobrelclserror; 259 243 260 return new RandomForestModel( dForest, seed, problemData, nTrees, r, m, classValues);244 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues); 261 245 } 262 246 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r13238 r13941 143 143 144 144 if (CreateSolution) { 145 var solution = new RandomForestRegressionSolution( (IRegressionProblemData)Problem.ProblemData.Clone(), model);145 var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone()); 146 146 Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution)); 147 147 } … … 153 153 var model = CreateRandomForestRegressionModel(problemData, nTrees, r, m, seed, 154 154 out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 155 return new RandomForestRegressionSolution( (IRegressionProblemData)problemData.Clone(), model);155 return new RandomForestRegressionSolution(model, (IRegressionProblemData)problemData.Clone()); 156 156 } 157 157 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs
r12012 r13941 43 43 : base(original, cloner) { 44 44 } 45 public RandomForestRegressionSolution(IR egressionProblemData problemData, IRandomForestModel randomForestModel)45 public RandomForestRegressionSolution(IRandomForestModel randomForestModel, IRegressionProblemData problemData) 46 46 : base(randomForestModel, problemData) { 47 47 }
Note: See TracChangeset
for help on using the changeset viewer.