Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/08/14 17:33:53 (11 years ago)
Author:
gkronber
Message:

#1721 refactored RandomForestModel and changed persistence (store data and parameters instead of model)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r9456 r10321  
    130130    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    131131      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    132       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    133       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    134 
    135       alglib.math.rndobject = new System.Random(seed);
    136 
    137       Dataset dataset = problemData.Dataset;
    138       string targetVariable = problemData.TargetVariable;
    139       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    140       IEnumerable<int> rows = problemData.TrainingIndices;
    141       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    142       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    143         throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");
    144 
    145       int info = 0;
    146       alglib.decisionforest dForest = new alglib.decisionforest();
    147       alglib.dfreport rep = new alglib.dfreport(); ;
    148       int nRows = inputMatrix.GetLength(0);
    149       int nColumns = inputMatrix.GetLength(1);
    150       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    151       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    152 
    153       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    154       if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    155 
    156       rmsError = rep.rmserror;
    157       avgRelError = rep.avgrelerror;
    158       outOfBagAvgRelError = rep.oobavgrelerror;
    159       outOfBagRmsError = rep.oobrmserror;
    160 
    161       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
     132      var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     133      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
    162134    }
    163135    #endregion
Note: See TracChangeset for help on using the changeset viewer.