Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/27/19 15:46:20 (5 years ago)
Author:
mkommend
Message:

#2952: Intermediate commit of refactoring RF models that is not yet finished.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r16763 r17045  
    2323using System.Collections.Generic;
    2424using System.Linq;
     25using HEAL.Attic;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
    2728using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    28 using HEAL.Attic;
    2929using HeuristicLab.Problems.DataAnalysis;
    3030using HeuristicLab.Problems.DataAnalysis.Symbolic;
     
    3434  /// Represents a random forest model for regression and classification
    3535  /// </summary>
    36   [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")]
     36  [Obsolete("This class only exists for backwards compatibility reasons. Use RFModelSurrogate or RFModelFull instead.")]
     37  [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")]
    3738  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3839  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
     
    139140    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    140141      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    141       AssertInputMatrix(inputData);
     142      RandomForestUtil.AssertInputMatrix(inputData);
    142143
    143144      int n = inputData.GetLength(0);
     
    157158    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    158159      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    159       AssertInputMatrix(inputData);
     160      RandomForestUtil.AssertInputMatrix(inputData);
    160161
    161162      int n = inputData.GetLength(0);
     
    175176    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    176177      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    177       AssertInputMatrix(inputData);
     178      RandomForestUtil.AssertInputMatrix(inputData);
    178179
    179180      int n = inputData.GetLength(0);
     
    315316
    316317      alglib.dfreport rep;
    317       var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     318      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
    318319
    319320      rmsError = rep.rmserror;
     
    353354
    354355      alglib.dfreport rep;
    355       var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     356      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
    356357
    357358      rmsError = rep.rmserror;
     
    361362
    362363      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    363     }
    364 
    365     private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
    366       AssertParameters(r, m);
    367       AssertInputMatrix(inputMatrix);
    368 
    369       int info = 0;
    370       alglib.math.rndobject = new System.Random(seed);
    371       var dForest = new alglib.decisionforest();
    372       rep = new alglib.dfreport();
    373       int nRows = inputMatrix.GetLength(0);
    374       int nColumns = inputMatrix.GetLength(1);
    375       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    376       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    377 
    378       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    379       if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
    380       return dForest;
    381     }
    382 
    383     private static void AssertParameters(double r, double m) {
    384       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
    385       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    386     }
    387 
    388     private static void AssertInputMatrix(double[,] inputMatrix) {
    389       if (inputMatrix.ContainsNanOrInfinity())
    390         throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    391364    }
    392365
Note: See TracChangeset for help on using the changeset viewer.