Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/27/19 15:46:20 (5 years ago)
Author:
mkommend
Message:

#2952: Intermediate commit of refactoring RF models that is not yet finished.

Location:
branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 added
8 edited

Legend:

Unmodified
Added
Removed
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r17030 r17045  
    269269
    270270        if (ModelCreation == ModelCreation.SurrogateModel) {
    271           model = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction, Iterations, MaxSize, R, M, Nu, (GradientBoostedTreesModel)model);
     271          model = new GradientBoostedTreesModelSurrogate((GradientBoostedTreesModel)model, problemData, (uint)Seed, lossFunction, Iterations, MaxSize, R, M, Nu);
    272272        }
    273273
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r17030 r17045  
    3434  // this is essentially a collection of weighted regression models
    3535  public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel {
    36     // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    37     #region Backwards compatible code, remove with 3.5
    38 
    3936    [Storable(Name = "models")]
    4037    private IList<IRegressionModel> __persistedModels {
     
    5350      get { return weights; }
    5451    }
    55     #endregion
    5652
    5753    public override IEnumerable<string> VariablesUsedForPrediction {
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r16565 r17045  
    9090    private Func<IGradientBoostedTreesModel> CreateLazyInitFunc(IGradientBoostedTreesModel clonedModel) {
    9191      return () => {
    92         return clonedModel == null ? RecalculateModel() : clonedModel;
     92        return clonedModel ?? RecalculateModel();
    9393      };
    9494    }
    9595
    9696    // create only the surrogate model without an actual model
    97     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed,
     97    private GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed,
    9898      ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
    9999      : base(trainingProblemData.TargetVariable, "Gradient boosted tree model", string.Empty) {
     
    106106      this.m = m;
    107107      this.nu = nu;
     108
     109      actualModel = new Lazy<IGradientBoostedTreesModel>(() => RecalculateModel());
    108110    }
    109111
    110112    // wrap an actual model in a surrograte
    111     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed,
    112       ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu,
    113       IGradientBoostedTreesModel model)
     113    public GradientBoostedTreesModelSurrogate(IGradientBoostedTreesModel model, IRegressionProblemData trainingProblemData, uint seed,
     114      ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
    114115      : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
    115116      actualModel = new Lazy<IGradientBoostedTreesModel>(() => model);
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r17042 r17045  
    394394    <Compile Include="RandomForest\RandomForestClassification.cs" />
    395395    <Compile Include="RandomForest\RandomForestModel.cs" />
     396    <Compile Include="RandomForest\RandomForestModelFull.cs" />
    396397    <Compile Include="RandomForest\RandomForestRegression.cs" />
    397398    <Compile Include="RandomForest\RandomForestRegressionSolution.cs" />
     399    <Compile Include="RandomForest\RandomForestModelSurrogate.cs" />
    398400    <Compile Include="RandomForest\RandomForestUtil.cs" />
    399401    <Compile Include="SupportVectorMachine\SupportVectorClassification.cs" />
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r16565 r17045  
    2121
    2222using System.Threading;
     23using HEAL.Attic;
    2324using HeuristicLab.Common;
    2425using HeuristicLab.Core;
     
    2627using HeuristicLab.Optimization;
    2728using HeuristicLab.Parameters;
    28 using HEAL.Attic;
    2929using HeuristicLab.Problems.DataAnalysis;
    3030
     
    144144
    145145      if (CreateSolution) {
    146         var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone());
     146        var solution = model.CreateClassificationSolution(Problem.ProblemData);
    147147        Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
    148148      }
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r16763 r17045  
    2323using System.Collections.Generic;
    2424using System.Linq;
     25using HEAL.Attic;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
    2728using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    28 using HEAL.Attic;
    2929using HeuristicLab.Problems.DataAnalysis;
    3030using HeuristicLab.Problems.DataAnalysis.Symbolic;
     
    3434  /// Represents a random forest model for regression and classification
    3535  /// </summary>
    36   [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")]
     36  [Obsolete("This class only exists for backwards compatibility reasons. Use RFModelSurrogate or RFModelFull instead.")]
     37  [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")]
    3738  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3839  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
     
    139140    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    140141      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    141       AssertInputMatrix(inputData);
     142      RandomForestUtil.AssertInputMatrix(inputData);
    142143
    143144      int n = inputData.GetLength(0);
     
    157158    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    158159      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    159       AssertInputMatrix(inputData);
     160      RandomForestUtil.AssertInputMatrix(inputData);
    160161
    161162      int n = inputData.GetLength(0);
     
    175176    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    176177      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
    177       AssertInputMatrix(inputData);
     178      RandomForestUtil.AssertInputMatrix(inputData);
    178179
    179180      int n = inputData.GetLength(0);
     
    315316
    316317      alglib.dfreport rep;
    317       var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     318      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
    318319
    319320      rmsError = rep.rmserror;
     
    353354
    354355      alglib.dfreport rep;
    355       var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     356      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
    356357
    357358      rmsError = rep.rmserror;
     
    361362
    362363      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    363     }
    364 
    365     private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
    366       AssertParameters(r, m);
    367       AssertInputMatrix(inputMatrix);
    368 
    369       int info = 0;
    370       alglib.math.rndobject = new System.Random(seed);
    371       var dForest = new alglib.decisionforest();
    372       rep = new alglib.dfreport();
    373       int nRows = inputMatrix.GetLength(0);
    374       int nColumns = inputMatrix.GetLength(1);
    375       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    376       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    377 
    378       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    379       if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
    380       return dForest;
    381     }
    382 
    383     private static void AssertParameters(double r, double m) {
    384       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
    385       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    386     }
    387 
    388     private static void AssertInputMatrix(double[,] inputMatrix) {
    389       if (inputMatrix.ContainsNanOrInfinity())
    390         throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    391364    }
    392365
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r16565 r17045  
    2020#endregion
    2121
     22using System.Collections.Generic;
     23using System.Linq;
    2224using System.Threading;
     25using HEAL.Attic;
    2326using HeuristicLab.Common;
    2427using HeuristicLab.Core;
     
    2629using HeuristicLab.Optimization;
    2730using HeuristicLab.Parameters;
    28 using HEAL.Attic;
    2931using HeuristicLab.Problems.DataAnalysis;
    3032
     
    144146
    145147      if (CreateSolution) {
    146         var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone());
     148        var solution = model.CreateRegressionSolution(Problem.ProblemData);
    147149        Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution));
    148150      }
    149151    }
     152
    150153
    151154    // keep for compatibility with old API
     
    157160    }
    158161
    159     public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
    160       double r, double m, int seed,
    161       out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    162       return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed,
    163         rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError);
     162    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
     163     double r, double m, int seed,
     164     out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     165      return CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     166    }
     167
     168    public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
     169    out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
     170
     171      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     172      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
     173
     174      alglib.dfreport rep;
     175      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     176
     177      rmsError = rep.rmserror;
     178      outOfBagRmsError = rep.oobrmserror;
     179      avgRelError = rep.avgrelerror;
     180      outOfBagAvgRelError = rep.oobavgrelerror;
     181
     182      return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables);
     183
     184      //return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed,
     185      //rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError);
    164186    }
    165187
  • branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r16565 r17045  
    2727using System.Linq.Expressions;
    2828using System.Threading.Tasks;
     29using HEAL.Attic;
    2930using HeuristicLab.Common;
    3031using HeuristicLab.Core;
    3132using HeuristicLab.Data;
    3233using HeuristicLab.Parameters;
    33 using HEAL.Attic;
    3434using HeuristicLab.Problems.DataAnalysis;
    3535using HeuristicLab.Random;
     
    8989
    9090  public static class RandomForestUtil {
     91    public static void AssertParameters(double r, double m) {
     92      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     93      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     94    }
     95
     96    public static void AssertInputMatrix(double[,] inputMatrix) {
     97      if (inputMatrix.ContainsNanOrInfinity())
     98        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     99    }
     100
     101    internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     102      RandomForestUtil.AssertParameters(r, m);
     103      RandomForestUtil.AssertInputMatrix(inputMatrix);
     104
     105      int info = 0;
     106      alglib.math.rndobject = new System.Random(seed);
     107      var dForest = new alglib.decisionforest();
     108      rep = new alglib.dfreport();
     109      int nRows = inputMatrix.GetLength(0);
     110      int nColumns = inputMatrix.GetLength(1);
     111      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     112      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     113
     114      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     115      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     116      return dForest;
     117    }
     118
     119
    91120    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    92121      avgTestMse = 0;
Note: See TracChangeset for help on using the changeset viewer.