Changeset 11443 for trunk/sources
- Timestamp: 10/10/14 13:58:19 (10 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r11426 r11443 30 30 using HeuristicLab.Core; 31 31 using HeuristicLab.Data; 32 using HeuristicLab.Parameters; 33 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 32 34 using HeuristicLab.Problems.DataAnalysis; 33 35 using HeuristicLab.Random; 34 36 35 37 namespace HeuristicLab.Algorithms.DataAnalysis { 36 public class RFParameter : ICloneable { 37 public double n; // number of trees 38 public double m; 39 public double r; 40 41 public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; } 38 [Item("RFParameter", "A random forest parameter collection")] 39 [StorableClass] 40 public class RFParameter : ParameterCollection { 41 public RFParameter() { 42 base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50))); 43 base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1))); 44 base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1))); 45 } 46 47 [StorableConstructor] 48 private RFParameter(bool deserializing) 49 : base(deserializing) { 50 } 51 52 private RFParameter(RFParameter original, Cloner cloner) 53 : base(original, cloner) { 54 this.N = original.N; 55 this.R = original.R; 56 this.M = original.M; 57 } 58 59 public override IDeepCloneable Clone(Cloner cloner) { 60 return new RFParameter(this, cloner); 61 } 62 63 private IFixedValueParameter<IntValue> NParameter { 64 get { return (IFixedValueParameter<IntValue>)base["N"]; } 65 } 66 67 private IFixedValueParameter<DoubleValue> RParameter { 68 get { return (IFixedValueParameter<DoubleValue>)base["R"]; } 69 } 70 71 private IFixedValueParameter<DoubleValue> MParameter { 72 get { return (IFixedValueParameter<DoubleValue>)base["M"]; } 73 } 74 75 public int N { 76 get { return NParameter.Value.Value; } 77 set { 
NParameter.Value.Value = value; } 78 } 79 80 public double R { 81 get { return RParameter.Value.Value; } 82 set { RParameter.Value.Value = value; } 83 } 84 85 public double M { 86 get { return MParameter.Value.Value; } 87 set { MParameter.Value.Value = value; } 88 } 42 89 } 43 90 … … 64 111 avgTestMse /= partitions.Length; 65 112 } 113 66 114 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) { 67 115 avgTestAccuracy = 0; … … 96 144 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 97 145 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 98 RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);146 RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 99 147 100 148 lock (locker) { … … 120 168 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 121 169 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 122 RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,170 RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, 123 171 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 124 172 … … 135 183 public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 136 
184 DoubleValue mse = new DoubleValue(Double.MaxValue); 137 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };185 RFParameter bestParameter = new RFParameter(); 138 186 139 187 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 148 196 setters[i](parameters, parameterValues[i]); 149 197 } 150 CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE);198 CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE); 151 199 152 200 lock (locker) { … … 162 210 public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 163 211 DoubleValue accuracy = new DoubleValue(0); 164 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };212 RFParameter bestParameter = new RFParameter(); 165 213 166 214 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 175 223 setters[i](parameters, parameterValues[i]); 176 224 } 177 CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);225 CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy); 178 226 179 227 lock (locker) { … … 256 304 var targetExp = Expression.Parameter(typeof(RFParameter)); 257 305 var valueExp = Expression.Parameter(typeof(double)); 258 var fieldExp = Expression. Field(targetExp, field);306 var fieldExp = Expression.Property(targetExp, field); 259 307 var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type)); 260 308 var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
Note: See TracChangeset for help on using the changeset viewer.