Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/11/14 10:25:04 (10 years ago)
Author:
gkronber
Message:

#1721 merged r10321:10322 from feature branch to the trunk

Location:
trunk/sources
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources

  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis

  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r9456 r10963  
    131131    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    132132      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    133       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    134       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    135 
    136       alglib.math.rndobject = new System.Random(seed);
    137 
    138       Dataset dataset = problemData.Dataset;
    139       string targetVariable = problemData.TargetVariable;
    140       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    141       IEnumerable<int> rows = problemData.TrainingIndices;
    142       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    143       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    144         throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
    145 
    146       int info = 0;
    147       alglib.decisionforest dForest = new alglib.decisionforest();
    148       alglib.dfreport rep = new alglib.dfreport(); ;
    149       int nRows = inputMatrix.GetLength(0);
    150       int nColumns = inputMatrix.GetLength(1);
    151       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    152       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    153 
    154 
    155       double[] classValues = problemData.ClassValues.ToArray();
    156       int nClasses = problemData.Classes;
    157       // map original class values to values [0..nClasses-1]
    158       Dictionary<double, double> classIndices = new Dictionary<double, double>();
    159       for (int i = 0; i < nClasses; i++) {
    160         classIndices[classValues[i]] = i;
    161       }
    162       for (int row = 0; row < nRows; row++) {
    163         inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    164       }
    165       // execute random forest algorithm     
    166       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    167       if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    168 
    169       rmsError = rep.rmserror;
    170       outOfBagRmsError = rep.oobrmserror;
    171       relClassificationError = rep.relclserror;
    172       outOfBagRelClassificationError = rep.oobrelclserror;
    173       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
     133      var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     134      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    174135    }
    175136    #endregion
Note: See TracChangeset for help on using the changeset viewer.