Changeset 10321 for branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
- Timestamp:
- 01/08/14 17:33:53 (11 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r9456 r10321 130 130 public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 131 131 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 132 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1."); 133 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1."); 134 135 alglib.math.rndobject = new System.Random(seed); 136 137 Dataset dataset = problemData.Dataset; 138 string targetVariable = problemData.TargetVariable; 139 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 140 IEnumerable<int> rows = problemData.TrainingIndices; 141 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 142 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 143 throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset."); 144 145 int info = 0; 146 alglib.decisionforest dForest = new alglib.decisionforest(); 147 alglib.dfreport rep = new alglib.dfreport(); ; 148 int nRows = inputMatrix.GetLength(0); 149 int nColumns = inputMatrix.GetLength(1); 150 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 151 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 152 153 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 154 if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution"); 155 156 rmsError = rep.rmserror; 157 avgRelError = rep.avgrelerror; 158 outOfBagAvgRelError = rep.oobavgrelerror; 159 outOfBagRmsError = rep.oobrmserror; 160 161 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables)); 132 var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 133 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model); 162 134 } 163 135 #endregion
Note: See TracChangeset
for help on using the changeset viewer.