Changeset 18027 for branches/3026_IntegrationIntoSymSpace/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
- Timestamp: 07/20/21 18:13:55 (3 years ago)
- Location: branches/3026_IntegrationIntoSymSpace
- Files: 4 edited
Legend:
- Unmodified
- Added
- Removed
- branches/3026_IntegrationIntoSymSpace
- branches/3026_IntegrationIntoSymSpace/HeuristicLab.Algorithms.DataAnalysis
  - Property svn:mergeinfo changed: /trunk/HeuristicLab.Algorithms.DataAnalysis merged: 17931,17934,17942,17979
- branches/3026_IntegrationIntoSymSpace/HeuristicLab.Algorithms.DataAnalysis/3.4
  - Property svn:mergeinfo changed: /trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 merged: 17931,17934,17942,17979
- branches/3026_IntegrationIntoSymSpace/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r17180 r18027 22 22 #endregion 23 23 24 extern alias alglib_3_7; 25 24 26 using System; 25 27 using System.Collections.Generic; … … 95 97 96 98 public static void AssertInputMatrix(double[,] inputMatrix) { 97 if (inputMatrix.ContainsNanOrInfinity())98 throw new NotSupportedException("Random forest modeling does not support NaN or infinityvalues in the input dataset.");99 foreach(var val in inputMatrix) if(double.IsNaN(val)) 100 throw new NotSupportedException("Random forest modeling does not support NaN values in the input dataset."); 99 101 } 100 102 … … 103 105 RandomForestUtil.AssertInputMatrix(inputMatrix); 104 106 107 int nRows = inputMatrix.GetLength(0); 108 int nColumns = inputMatrix.GetLength(1); 109 110 alglib.dfbuildercreate(out var dfbuilder); 111 alglib.dfbuildersetdataset(dfbuilder, inputMatrix, nRows, nColumns - 1, nClasses); 112 alglib.dfbuildersetimportancenone(dfbuilder); // do not calculate importance (TODO add this feature) 113 alglib.dfbuildersetrdfalgo(dfbuilder, 0); // only one algorithm supported in version 3.17 114 alglib.dfbuildersetrdfsplitstrength(dfbuilder, 2); // 0 = split at the random position, fastest one 115 // 1 = split at the middle of the range 116 // 2 = strong split at the best point of the range (default) 117 alglib.dfbuildersetrndvarsratio(dfbuilder, m); 118 alglib.dfbuildersetsubsampleratio(dfbuilder, r); 119 alglib.dfbuildersetseed(dfbuilder, seed); 120 alglib.dfbuilderbuildrandomforest(dfbuilder, nTrees, out var dForest, out rep); 121 return dForest; 122 } 123 internal static alglib_3_7.alglib.decisionforest CreateRandomForestModelAlglib_3_7(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib_3_7.alglib.dfreport rep) { 124 RandomForestUtil.AssertParameters(r, m); 125 RandomForestUtil.AssertInputMatrix(inputMatrix); 126 105 127 int info = 0; 106 alglib .math.rndobject = new System.Random(seed);107 var dForest = new alglib .decisionforest();108 rep = new alglib .dfreport();128 
alglib_3_7.alglib.math.rndobject = new System.Random(seed); 129 var dForest = new alglib_3_7.alglib.decisionforest(); 130 rep = new alglib_3_7.alglib.dfreport(); 109 131 int nRows = inputMatrix.GetLength(0); 110 132 int nColumns = inputMatrix.GetLength(1); … … 112 134 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 113 135 114 alglib .dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits +alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);136 alglib_3_7.alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib_3_7.alglib.dforest.dfusestrongsplits + alglib_3_7.alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 115 137 if (info != 1) throw new ArgumentException("Error in calculation of random forest model"); 116 138 return dForest; … … 123 145 var targetVariable = GetTargetVariableName(problemData); 124 146 foreach (var tuple in partitions) { 125 double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;126 147 var trainingRandomForestPartition = tuple.Item1; 127 148 var testRandomForestPartition = tuple.Item2; 128 var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 149 var model = RandomForestRegression.CreateRandomForestRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, 150 out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError); 129 151 var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition); 130 152 var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition); … … 143 165 var targetVariable = GetTargetVariableName(problemData); 144 166 foreach (var tuple in partitions) { 145 double rmsError, avgRelError, outOfBagAvgRelError, 
outOfBagRmsError;146 167 var trainingRandomForestPartition = tuple.Item1; 147 168 var testRandomForestPartition = tuple.Item2; 148 var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 169 var model = RandomForestClassification.CreateRandomForestClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, 170 out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError); 149 171 var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition); 150 172 var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition); … … 176 198 var parameters = new RFParameter(); 177 199 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 178 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;179 RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, outoutOfBagAvgRelError);200 RandomForestRegression.CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, 201 out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError); 180 202 181 203 lock (locker) { … … 208 230 var parameters = new RFParameter(); 209 231 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 210 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 211 RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, 212 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 232 RandomForestClassification.CreateRandomForestClassificationModel(problemData, 
problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, 233 out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError); 213 234 214 235 lock (locker) { … … 227 248 /// <param name="problemData">The regression problem data</param> 228 249 /// <param name="numberOfFolds">The number of folds for crossvalidation</param> 229 /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>230 250 /// <param name="parameterRanges">The ranges for each parameter in the grid search</param> 231 251 /// <param name="seed">The random seed (required by the random forest model)</param> 232 252 /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param> 233 253 /// <returns>The best parameter values found by the grid search</returns> 234 public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds,Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {254 public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 235 255 DoubleValue mse = new DoubleValue(Double.MaxValue); 236 256 RFParameter bestParameter = new RFParameter();
Note: See TracChangeset for help on using the changeset viewer.