Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/23/19 20:22:52 (5 years ago)
Author:
gkronber
Message:

#2952: merged r17154 from trunk to stable

Location:
stable
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r17097 r17157  
    2727using System.Linq.Expressions;
    2828using System.Threading.Tasks;
     29using HEAL.Attic;
    2930using HeuristicLab.Common;
    3031using HeuristicLab.Core;
    3132using HeuristicLab.Data;
    3233using HeuristicLab.Parameters;
    33 using HEAL.Attic;
    3434using HeuristicLab.Problems.DataAnalysis;
    3535using HeuristicLab.Random;
     
    8989
    9090  public static class RandomForestUtil {
     91    public static void AssertParameters(double r, double m) {
     92      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     93      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     94    }
     95
     96    public static void AssertInputMatrix(double[,] inputMatrix) {
     97      if (inputMatrix.ContainsNanOrInfinity())
     98        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     99    }
     100
     101    internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     102      RandomForestUtil.AssertParameters(r, m);
     103      RandomForestUtil.AssertInputMatrix(inputMatrix);
     104
     105      int info = 0;
     106      alglib.math.rndobject = new System.Random(seed);
     107      var dForest = new alglib.decisionforest();
     108      rep = new alglib.dfreport();
     109      int nRows = inputMatrix.GetLength(0);
     110      int nColumns = inputMatrix.GetLength(1);
     111      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     112      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     113
     114      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     115      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     116      return dForest;
     117    }
     118
     119
    91120    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    92121      avgTestMse = 0;
Note: See TracChangeset for help on using the changeset viewer.