- Timestamp:
- 11/27/14 11:23:37 (10 years ago)
- Location:
- branches/Breadcrumbs
- Files:
-
- 7 edited
- 1 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/Breadcrumbs
- Property svn:ignore
-
old new 8 8 FxCopResults.txt 9 9 Google.ProtocolBuffers-0.9.1.dll 10 Google.ProtocolBuffers-2.4.1.473.dll 10 11 HeuristicLab 3.3.5.1.ReSharper.user 11 12 HeuristicLab 3.3.6.0.ReSharper.user 12 13 HeuristicLab.4.5.resharper.user 13 14 HeuristicLab.ExtLibs.6.0.ReSharper.user 15 HeuristicLab.Scripting.Development 14 16 HeuristicLab.resharper.user 15 17 ProtoGen.exe … … 17 19 _ReSharper.HeuristicLab 18 20 _ReSharper.HeuristicLab 3.3 21 _ReSharper.HeuristicLab 3.3 Tests 19 22 _ReSharper.HeuristicLab.ExtLibs 20 23 bin 21 24 protoc.exe 22 _ReSharper.HeuristicLab 3.3 Tests23 Google.ProtocolBuffers-2.4.1.473.dll
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r9456 r11594 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 3Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 131 131 public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 132 132 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 133 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1."); 134 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1."); 135 136 alglib.math.rndobject = new System.Random(seed); 137 138 Dataset dataset = problemData.Dataset; 139 string targetVariable = problemData.TargetVariable; 140 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 141 IEnumerable<int> rows = problemData.TrainingIndices; 142 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 143 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 144 throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset."); 145 146 int info = 0; 147 alglib.decisionforest dForest = new alglib.decisionforest(); 148 alglib.dfreport rep = new alglib.dfreport(); ; 149 int nRows = inputMatrix.GetLength(0); 150 int nColumns = inputMatrix.GetLength(1); 151 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 152 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 153 154 155 double[] classValues = problemData.ClassValues.ToArray(); 156 int nClasses = problemData.Classes; 157 // map original class values to values [0..nClasses-1] 158 Dictionary<double, double> classIndices = new Dictionary<double, double>(); 159 for (int i = 0; i < nClasses; i++) { 160 classIndices[classValues[i]] = i; 161 } 162 for (int row = 0; row < nRows; row++) { 163 inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; 164 } 165 // execute random forest algorithm 166 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 167 if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution"); 168 169 rmsError = rep.rmserror; 170 outOfBagRmsError = rep.oobrmserror; 171 relClassificationError = rep.relclserror; 172 outOfBagRelClassificationError = rep.oobrelclserror; 173 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues)); 133 var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 134 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model); 174 135 } 175 136 #endregion -
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs
r9456 r11594 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 3Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. -
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r9456 r11594 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 3Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 35 35 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 36 36 public sealed class RandomForestModel : NamedItem, IRandomForestModel { 37 37 // not persisted 38 38 private alglib.decisionforest randomForest; 39 public alglib.decisionforest RandomForest { 40 get { return randomForest; } 41 set { 42 if (value != randomForest) { 43 if (value == null) throw new ArgumentNullException(); 44 randomForest = value; 45 OnChanged(EventArgs.Empty); 46 } 47 } 48 } 49 50 [Storable] 51 private string targetVariable; 52 [Storable] 53 private string[] allowedInputVariables; 39 private alglib.decisionforest RandomForest { 40 get { 41 // recalculate lazily 42 if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel(); 43 return randomForest; 44 } 45 } 46 47 // instead of storing the data of the model itself 48 // we instead only store data necessary to recalculate the same model lazily on demand 49 [Storable] 50 private int seed; 51 [Storable] 52 private IDataAnalysisProblemData originalTrainingData; 54 53 [Storable] 55 54 private double[] classValues; 55 [Storable] 56 private int nTrees; 57 [Storable] 58 private double r; 59 [Storable] 60 private double m; 61 62 56 63 [StorableConstructor] 57 64 private RandomForestModel(bool deserializing) 58 65 : base(deserializing) { 59 if (deserializing)60 66 // for backwards compatibility (loading old solutions) 67 randomForest = new alglib.decisionforest(); 61 68 } 62 69 private RandomForestModel(RandomForestModel original, Cloner cloner) … … 67 74 randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees; 68 75 randomForest.innerobj.nvars = original.randomForest.innerobj.nvars; 69 randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone(); 70 targetVariable = original.targetVariable; 71 allowedInputVariables = (string[])original.allowedInputVariables.Clone(); 72 if (original.classValues != null) 73 this.classValues = (double[])original.classValues.Clone(); 74 } 75 public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) 76 // we assume that the trees array (double[]) is immutable in alglib 77 randomForest.innerobj.trees = original.randomForest.innerobj.trees; 78 79 // allowedInputVariables is immutable so we don't need to clone 80 allowedInputVariables = original.allowedInputVariables; 81 82 // clone data which is necessary to rebuild the model 83 this.seed = original.seed; 84 this.originalTrainingData = cloner.Clone(original.originalTrainingData); 85 // classvalues is immutable so we don't need to clone 86 this.classValues = original.classValues; 87 this.nTrees = original.nTrees; 88 this.r = original.r; 89 this.m = original.m; 90 } 91 92 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 93 private RandomForestModel(alglib.decisionforest randomForest, 94 int seed, IDataAnalysisProblemData originalTrainingData, 95 int nTrees, double r, double m, double[] classValues = null) 76 96 : base() { 77 97 this.name = ItemName; 78 98 this.description = ItemDescription; 99 // the model itself 79 100 this.randomForest = randomForest; 80 this.targetVariable = targetVariable; 81 this.allowedInputVariables = allowedInputVariables.ToArray(); 82 if (classValues != null) 83 this.classValues = (double[])classValues.Clone(); 101 // data which is necessary for recalculation of the model 102 this.seed = seed; 103 this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone(); 104 this.classValues = classValues; 105 this.nTrees = nTrees; 106 this.r = r; 107 this.m = m; 84 108 } 85 109 … … 88 112 } 89 113 114 private void RecalculateModel() { 115 double rmsError, oobRmsError, relClassError, oobRelClassError; 116 var regressionProblemData = originalTrainingData as IRegressionProblemData; 117 var classificationProblemData = originalTrainingData as IClassificationProblemData; 118 if (regressionProblemData != null) { 119 var model = CreateRegressionModel(regressionProblemData, 120 nTrees, r, m, seed, out rmsError, out oobRmsError, 121 out relClassError, out oobRelClassError); 122 randomForest = model.randomForest; 123 } else if (classificationProblemData != null) { 124 var model = CreateClassificationModel(classificationProblemData, 125 nTrees, r, m, seed, out rmsError, out oobRmsError, 126 out relClassError, out oobRelClassError); 127 randomForest = model.randomForest; 128 } 129 } 130 90 131 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 91 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 132 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 133 AssertInputMatrix(inputData); 92 134 93 135 int n = inputData.GetLength(0); … … 100 142 x[column] = inputData[row, column]; 101 143 } 102 alglib.dfprocess( randomForest, x, ref y);144 alglib.dfprocess(RandomForest, x, ref y); 103 145 yield return y[0]; 104 146 } … … 106 148 107 149 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 108 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 150 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 151 AssertInputMatrix(inputData); 109 152 110 153 int n = inputData.GetLength(0); 111 154 int columns = inputData.GetLength(1); 112 155 double[] x = new double[columns]; 113 double[] y = new double[ randomForest.innerobj.nclasses];156 double[] y = new double[RandomForest.innerobj.nclasses]; 114 157 115 158 for (int row = 0; row < n; row++) { … … 144 187 } 145 188 146 #region events 147 public event EventHandler Changed; 148 private void OnChanged(EventArgs e) { 149 var handlers = Changed; 150 if (handlers != null) 151 handlers(this, e); 152 } 153 #endregion 154 155 #region persistence 189 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 190 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { 191 return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError); 192 } 193 194 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed, 195 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { 196 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 197 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices); 198 199 alglib.dfreport rep; 200 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); 201 202 rmsError = rep.rmserror; 203 avgRelError = rep.avgrelerror; 204 outOfBagAvgRelError = rep.oobavgrelerror; 205 outOfBagRmsError = rep.oobrmserror; 206 207 return new RandomForestModel(dForest,seed, problemData,nTrees, r, m); 208 } 209 210 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 211 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 212 return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError); 213 } 214 215 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed, 216 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 217 218 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 219 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices); 220 221 var classValues = problemData.ClassValues.ToArray(); 222 int nClasses = classValues.Length; 223 224 // map original class values to values [0..nClasses-1] 225 var classIndices = new Dictionary<double, double>(); 226 for (int i = 0; i < nClasses; i++) { 227 classIndices[classValues[i]] = i; 228 } 229 230 int nRows = inputMatrix.GetLength(0); 231 int nColumns = inputMatrix.GetLength(1); 232 for (int row = 0; row < nRows; row++) { 233 inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; 234 } 235 236 alglib.dfreport rep; 237 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); 238 239 rmsError = rep.rmserror; 240 outOfBagRmsError = rep.oobrmserror; 241 relClassificationError = rep.relclserror; 242 outOfBagRelClassificationError = rep.oobrelclserror; 243 244 return new RandomForestModel(dForest,seed, problemData,nTrees, r, m, classValues); 245 } 246 247 private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) { 248 AssertParameters(r, m); 249 AssertInputMatrix(inputMatrix); 250 251 int info = 0; 252 alglib.math.rndobject = new System.Random(seed); 253 var dForest = new alglib.decisionforest(); 254 rep = new alglib.dfreport(); 255 int nRows = inputMatrix.GetLength(0); 256 int nColumns = inputMatrix.GetLength(1); 257 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 258 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 259 260 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 261 if (info != 1) throw new ArgumentException("Error in calculation of random forest model"); 262 return dForest; 263 } 264 265 private static void AssertParameters(double r, double m) { 266 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1."); 267 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1."); 268 } 269 270 private static void AssertInputMatrix(double[,] inputMatrix) { 271 if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x))) 272 throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset."); 273 } 274 275 #region persistence for backwards compatibility 276 // when the originalTrainingData is null this means the model was loaded from an old file 277 // therefore, we cannot use the new persistence mechanism because the original data is not available anymore 278 // in such cases we still store the compete model 279 private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } } 280 281 private string[] allowedInputVariables; 282 [Storable(Name = "allowedInputVariables")] 283 private string[] AllowedInputVariables { 284 get { 285 if (IsCompatibilityLoaded) return allowedInputVariables; 286 else return originalTrainingData.AllowedInputVariables.ToArray(); 287 } 288 set { allowedInputVariables = value; } 289 } 156 290 [Storable] 157 291 private int RandomForestBufSize { 158 292 get { 159 return randomForest.innerobj.bufsize; 293 if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize; 294 else return 0; 160 295 } 161 296 set { … … 166 301 private int RandomForestNClasses { 167 302 get { 168 return randomForest.innerobj.nclasses; 303 if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses; 304 else return 0; 169 305 } 170 306 set { … … 175 311 private int RandomForestNTrees { 176 312 get { 177 return randomForest.innerobj.ntrees; 313 if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees; 314 else return 0; 178 315 } 179 316 set { … … 184 321 private int RandomForestNVars { 185 322 get { 186 return randomForest.innerobj.nvars; 323 if (IsCompatibilityLoaded) return randomForest.innerobj.nvars; 324 else return 0; 187 325 } 188 326 set { … … 193 331 private double[] RandomForestTrees { 194 332 get { 195 return randomForest.innerobj.trees; 333 if (IsCompatibilityLoaded) return randomForest.innerobj.trees; 334 else return new double[] { }; 196 335 } 197 336 set { -
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r9456 r11594 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 3Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 130 130 public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 131 131 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 132 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1."); 133 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1."); 134 135 alglib.math.rndobject = new System.Random(seed); 136 137 Dataset dataset = problemData.Dataset; 138 string targetVariable = problemData.TargetVariable; 139 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 140 IEnumerable<int> rows = problemData.TrainingIndices; 141 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 142 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 143 throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset."); 144 145 int info = 0; 146 alglib.decisionforest dForest = new alglib.decisionforest(); 147 alglib.dfreport rep = new alglib.dfreport(); ; 148 int nRows = inputMatrix.GetLength(0); 149 int nColumns = inputMatrix.GetLength(1); 150 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 151 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 152 153 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 154 if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution"); 155 156 rmsError = rep.rmserror; 157 avgRelError = rep.avgrelerror; 158 outOfBagAvgRelError = rep.oobavgrelerror; 159 outOfBagRmsError = rep.oobrmserror; 160 161 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables)); 132 var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 133 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model); 162 134 } 163 135 #endregion -
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs
r9456 r11594 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 3Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. -
branches/Breadcrumbs/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r11400 r11594 30 30 using HeuristicLab.Core; 31 31 using HeuristicLab.Data; 32 using HeuristicLab.Parameters; 33 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 32 34 using HeuristicLab.Problems.DataAnalysis; 33 35 using HeuristicLab.Random; 34 36 35 37 namespace HeuristicLab.Algorithms.DataAnalysis { 36 public class RFParameter : ICloneable { 37 public double n; // number of trees 38 public double m; 39 public double r; 40 41 public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; } 38 [Item("RFParameter", "A random forest parameter collection")] 39 [StorableClass] 40 public class RFParameter : ParameterCollection { 41 public RFParameter() { 42 base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50))); 43 base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1))); 44 base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1))); 45 } 46 47 [StorableConstructor] 48 protected RFParameter(bool deserializing) 49 : base(deserializing) { 50 } 51 52 protected RFParameter(RFParameter original, Cloner cloner) 53 : base(original, cloner) { 54 this.N = original.N; 55 this.R = original.R; 56 this.M = original.M; 57 } 58 59 public override IDeepCloneable Clone(Cloner cloner) { 60 return new RFParameter(this, cloner); 61 } 62 63 private IFixedValueParameter<IntValue> NParameter { 64 get { return (IFixedValueParameter<IntValue>)base["N"]; } 65 } 66 67 private IFixedValueParameter<DoubleValue> RParameter { 68 get { return (IFixedValueParameter<DoubleValue>)base["R"]; } 69 } 70 71 private IFixedValueParameter<DoubleValue> MParameter { 72 get { return (IFixedValueParameter<DoubleValue>)base["M"]; } 73 } 74 75 public int N { 76 get { return NParameter.Value.Value; } 77 set { NParameter.Value.Value = value; } 78 } 79 80 public double R { 81 get { return RParameter.Value.Value; } 82 set { RParameter.Value.Value = value; } 83 } 84 85 public double M { 86 get { return MParameter.Value.Value; } 87 set { MParameter.Value.Value = value; } 88 } 42 89 } 43 90 44 91 public static class RandomForestUtil { 92 private static readonly object locker = new object(); 93 45 94 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 46 95 avgTestMse = 0; … … 62 111 avgTestMse /= partitions.Length; 63 112 } 113 64 114 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) { 65 115 avgTestAccuracy = 0; … … 91 141 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 92 142 var parameterValues = parameterCombination.ToList(); 93 double testMSE;94 143 var parameters = new RFParameter(); 95 144 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 96 145 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 97 var model = RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,98 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 99 if (bestOutOfBagRmsError > outOfBagRmsError) {100 lock (bestParameters) {146 RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 147 148 lock (locker) { 149 if (bestOutOfBagRmsError > outOfBagRmsError) { 101 150 bestOutOfBagRmsError = outOfBagRmsError; 102 151 bestParameters = (RFParameter)parameters.Clone(); … … 119 168 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 120 169 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 121 var model = RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,170 RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, 122 171 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 123 if (bestOutOfBagRmsError > outOfBagRmsError) { 124 lock (bestParameters) { 172 173 lock (locker) { 174 if (bestOutOfBagRmsError > outOfBagRmsError) { 125 175 bestOutOfBagRmsError = outOfBagRmsError; 126 176 bestParameters = (RFParameter)parameters.Clone(); … … 133 183 public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 134 184 DoubleValue mse = new DoubleValue(Double.MaxValue); 135 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };185 RFParameter bestParameter = new RFParameter(); 136 186 137 187 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 146 196 setters[i](parameters, parameterValues[i]); 147 197 } 148 CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE); 149 if (testMSE < mse.Value) { 150 lock (mse) { 198 CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE); 199 200 lock (locker) { 201 if (testMSE < mse.Value) { 151 202 mse.Value = testMSE; 152 203 bestParameter = (RFParameter)parameters.Clone(); … … 159 210 public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 160 211 DoubleValue accuracy = new DoubleValue(0); 161 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };212 RFParameter bestParameter = new RFParameter(); 162 213 163 214 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 172 223 setters[i](parameters, parameterValues[i]); 173 224 } 174 CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy); 175 if (testAccuracy > accuracy.Value) { 176 lock (accuracy) { 225 CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy); 226 227 lock (locker) { 228 if (testAccuracy > accuracy.Value) { 177 229 accuracy.Value = testAccuracy; 178 230 bestParameter = (RFParameter)parameters.Clone(); … … 252 304 var targetExp = Expression.Parameter(typeof(RFParameter)); 253 305 var valueExp = Expression.Parameter(typeof(double)); 254 var fieldExp = Expression. Field(targetExp, field);306 var fieldExp = Expression.Property(targetExp, field); 255 307 var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type)); 256 308 var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
Note: See TracChangeset
for help on using the changeset viewer.