Changeset 17045 for branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
- Timestamp:
- 06/27/19 15:46:20 (6 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
// Changeset r16565 -> r17045: text of the visible hunks as of r17045 (viewer gutter numbers removed).
using System.Linq.Expressions;
using System.Threading.Tasks;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;
// ... lines 36-88 are not shown by the changeset viewer ...

  public static class RandomForestUtil {
    /// <summary>
    /// Checks that both random forest ratio parameters lie in the interval (0, 1].
    /// </summary>
    /// <param name="r">The R parameter; must be greater than 0 and at most 1.</param>
    /// <param name="m">The M parameter; must be greater than 0 and at most 1.</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="r"/> or <paramref name="m"/> is outside (0, 1] or is NaN.</exception>
    public static void AssertParameters(double r, double m) {
      // The range test is written in its negated positive form on purpose: every comparison with
      // double.NaN evaluates to false, so "r <= 0 || r > 1" would silently accept NaN.
      if (!(r > 0 && r <= 1)) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (!(m > 0 && m <= 1)) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    /// <summary>
    /// Checks that the input matrix can be used for random forest modeling.
    /// </summary>
    /// <param name="inputMatrix">The matrix to check.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="inputMatrix"/> is null.</exception>
    /// <exception cref="NotSupportedException">Thrown when ContainsNanOrInfinity() reports true for the matrix.</exception>
    public static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix == null) throw new ArgumentNullException("inputMatrix");
      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    /// <summary>
    /// Validates the arguments and builds a decision forest with alglib.dforest.dfbuildinternal.
    /// </summary>
    /// <param name="seed">Seed for the System.Random instance that is assigned to alglib.math.rndobject.</param>
    /// <param name="inputMatrix">
    /// Data matrix. All columns except one are passed to alglib as variables (nColumns - 1);
    /// presumably the last column holds the target values - TODO confirm against callers.
    /// </param>
    /// <param name="nTrees">Number of trees, passed through to alglib.</param>
    /// <param name="r">Ratio in (0, 1]; the sample size is round(r * rows), but at least 1.</param>
    /// <param name="m">Ratio in (0, 1]; the feature count is round(m * (columns - 1)), but at least 1.</param>
    /// <param name="nClasses">Number of classes, passed through to alglib.</param>
    /// <param name="rep">Receives a newly created report object whose inner object is handed to alglib.</param>
    /// <returns>The decision forest that was handed to alglib for building.</returns>
    /// <exception cref="ArgumentException">Thrown for invalid r/m values or when alglib reports an info code other than 1.</exception>
    /// <exception cref="NotSupportedException">Thrown when the input matrix contains NaN or infinity values.</exception>
    internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      RandomForestUtil.AssertParameters(r, m);
      RandomForestUtil.AssertInputMatrix(inputMatrix);

      int info = 0;
      // NOTE(review): this assigns a static alglib member, so concurrent callers would presumably
      // overwrite each other's random number generator - confirm before calling this in parallel.
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      // Both derived sizes are clamped to a minimum of 1 so that small ratios never round down to 0.
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      // Any info code other than 1 is treated as a failed build.
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }


    private static void
CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 92 121 avgTestMse = 0;
Note: See TracChangeset
for help on using the changeset viewer.